{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Quantum Classifier\n",
    "\n",
    "<em> Copyright (c) 2021 Institute for Quantum Computing, Baidu Inc. All Rights Reserved. </em>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Overview\n",
    "\n",
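    "In this tutorial we discuss how quantum classifiers work and how to use a quantum neural network (QNN) to perform **binary classification**. Representative early work on this approach includes the [Quantum Circuit Learning (QCL)](https://arxiv.org/abs/1803.00745) framework of Mitarai et al. (2018) [1], Farhi & Neven (2018) [2], and the [Circuit-Centric Quantum Classifiers](https://arxiv.org/abs/1804.00633) of Schuld et al. (2018) [3]. Here we use the first of these, the QCL framework applied to supervised learning, as our example. The general recipe is to first encode the classical data into quantum data and then train the parameters of the quantum neural network to obtain an optimal classifier."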
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Background\n",
    "\n",
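    "In supervised learning, the input is a dataset $D = \\{(x^k,y^k)\\}_{k=1}^{N}$ of $N$ labeled data points, where $x^k\\in \\mathbb{R}^{m}$ is a data point and $y^k \\in\\{0,1\\}$ is the class label of $x^k$. **Classification is essentially a decision process: deciding which label a given data point belongs to.** In the quantum classifier framework, the classifier $\\mathcal{F}$ is implemented as a quantum neural network (a parameterized quantum circuit) with parameters $\\theta$, combined with a measurement of the quantum system and classical post-processing of the measurement results. A good classifier $\\mathcal{F}_\\theta$ should map as many data points in the dataset as possible to their correct labels, $\\mathcal{F}_\\theta(x^k) \\rightarrow y^k$. We therefore optimize a loss function $\\mathcal{L}(\\theta)$ defined as the accumulated distance between the predicted labels $\\tilde{y}^{k} = \\mathcal{F}_\\theta(x^k)$ and the true labels $y^k$. For binary classification, one can use the quadratic loss\n",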
    "\n",
    "$$\n",
    "\\mathcal{L}(\\theta) = \\sum_{k=1}^N |\\tilde{y}^{k}-y^k|^2. \\tag{1}\n",
    "$$\n",
    "\n"
   ]
  },
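  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch in plain NumPy (not part of the Paddle Quantum API), the quadratic loss of Eq. (1) can be evaluated for a batch of predicted labels as follows. The function name `quadratic_loss` and the four predictions are hypothetical values chosen only for illustration; a perfect classifier, with $\\tilde{y}^{k} = y^k$ for every $k$, gives a loss of 0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def quadratic_loss(y_pred, y_true):\n",
    "    # Eq. (1): sum over all data points of |y_pred - y_true|^2\n",
    "    y_pred = np.asarray(y_pred, dtype=float)\n",
    "    y_true = np.asarray(y_true, dtype=float)\n",
    "    return float(np.sum(np.abs(y_pred - y_true) ** 2))\n",
    "\n",
    "\n",
    "# Hypothetical true labels and predictions for four data points\n",
    "y_true = np.array([0.0, 1.0, 1.0, 0.0])\n",
    "y_pred = np.array([0.1, 0.8, 0.6, 0.3])\n",
    "print(quadratic_loss(y_pred, y_true))  # approximately 0.3 = 0.01 + 0.04 + 0.16 + 0.09"
   ]
  },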
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Workflow\n",
    "\n",
    "Below is a workflow for implementing a quantum classifier within the quantum circuit learning (QCL) framework.\n",
    "\n",
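    "1. Apply a data-dependent parameterized unitary gate $U(x^k)$ to qubits initialized in $\\lvert 0 \\rangle$, encoding the classical data point $x^k$ into quantum data $\\lvert \\psi_{in}\\rangle^k$ that a quantum computer can process.\n",
    "2. Pass the input state $\\lvert \\psi_{in} \\rangle^k$ through the parameterized circuit $U(\\theta)$ with parameters $\\theta$ to obtain the output state $\\lvert \\psi_{out}\\rangle^k = U(\\theta)\\lvert \\psi_{in} \\rangle^k$.\n",
    "3. Measure the state $\\lvert \\psi_{out}\\rangle^k$ produced by the quantum neural network and post-process the measurement results to obtain the estimated label $\\tilde{y}^{k}$.\n",
    "4. Repeat steps 1-3 until every data point in the dataset has been processed, then compute the loss function $\\mathcal{L}(\\theta)$.\n",
    "5. Iteratively adjust the parameters $\\theta$ with an optimization method such as gradient descent to minimize the loss function. The parameters $\\theta^*$ recorded at the end of the optimization define the learned optimal classifier $\\mathcal{F}_{\\theta^*}$.\n",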
    "\n",
    "\n",
    "\n",
    "![QCL](figures/qclassifier-fig-pipeline-cn.png \"Figure 1: Training workflow of the quantum classifier\")\n",
    "<div style=\"text-align:center\">Figure 1: Training workflow of the quantum classifier </div>"
   ]
  },
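  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The five steps above can be illustrated with a deliberately tiny, self-contained NumPy sketch. It is not the Paddle Quantum implementation developed below. It assumes a single qubit, a hypothetical one-dimensional toy dataset of 40 evenly spaced points $x \\in [-1, 1]$ labeled 1 when $x > 0$, the encoding $\\lvert \\psi_{in} \\rangle = R_y(\\arcsin x)\\lvert 0 \\rangle$, a one-parameter circuit $U(\\theta) = R_y(\\theta)$, and the post-processing $\\tilde{y} = (1 - \\langle Z \\rangle)/2$, i.e. the probability of measuring $\\lvert 1 \\rangle$. Step 5 uses plain gradient descent, with the derivative of each prediction obtained from the parameter-shift rule, which is exact for a single $R_y$ parameter. All function and variable names are hypothetical. Because two consecutive $R_y$ rotations simply add their angles, this toy model can only shift the decision threshold, whereas the classifier in this tutorial encodes two-dimensional data on at least two qubits."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "PAULI_Z = np.diag([1.0, -1.0])\n",
    "KET0 = np.array([1.0, 0.0], dtype=complex)\n",
    "\n",
    "\n",
    "def ry(angle):\n",
    "    # R_y(angle) = exp(-i * angle * Y / 2)\n",
    "    c, s = np.cos(angle / 2), np.sin(angle / 2)\n",
    "    return np.array([[c, -s], [s, c]], dtype=complex)\n",
    "\n",
    "\n",
    "def toy_encode(x):\n",
    "    # Step 1: encode a scalar x in [-1, 1] as |psi_in> = R_y(arcsin x)|0>\n",
    "    return ry(np.arcsin(x)) @ KET0\n",
    "\n",
    "\n",
    "def toy_predict(theta, x):\n",
    "    # Step 2: apply the one-parameter circuit U(theta) = R_y(theta)\n",
    "    psi_out = ry(theta) @ toy_encode(x)\n",
    "    # Step 3: expectation value <Z>, post-processed into the probability of outcome 1\n",
    "    exp_z = np.real(np.vdot(psi_out, PAULI_Z @ psi_out))\n",
    "    return (1 - exp_z) / 2\n",
    "\n",
    "\n",
    "def toy_loss(theta, xs, ys):\n",
    "    # Step 4: quadratic loss of Eq. (1) over the whole dataset\n",
    "    return sum((toy_predict(theta, x) - y) ** 2 for x, y in zip(xs, ys))\n",
    "\n",
    "\n",
    "def toy_grad(theta, xs, ys):\n",
    "    # dL/dtheta, with d(prediction)/dtheta from the parameter-shift rule\n",
    "    g = 0.0\n",
    "    for x, y in zip(xs, ys):\n",
    "        shift = (toy_predict(theta + np.pi / 2, x) - toy_predict(theta - np.pi / 2, x)) / 2\n",
    "        g += 2 * (toy_predict(theta, x) - y) * shift\n",
    "    return g\n",
    "\n",
    "\n",
    "# Hypothetical toy data: 40 evenly spaced points, label 1 if x > 0, else 0\n",
    "toy_x = np.linspace(-1, 1, 40)\n",
    "toy_y = (toy_x > 0).astype(float)\n",
    "\n",
    "# Step 5: plain gradient descent on the single parameter theta\n",
    "toy_theta = 0.0\n",
    "for _ in range(200):\n",
    "    toy_theta -= 0.02 * toy_grad(toy_theta, toy_x, toy_y)\n",
    "\n",
    "toy_pred = np.array([toy_predict(toy_theta, x) for x in toy_x])\n",
    "toy_accuracy = np.mean((toy_pred > 0.5) == toy_y)\n",
    "print('theta =', toy_theta, ' loss =', toy_loss(toy_theta, toy_x, toy_y), ' accuracy =', toy_accuracy)"
   ]
  },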
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Paddle Quantum Implementation\n",
    "\n",
    "First, we import the required packages:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-02T09:15:03.419838Z",
     "start_time": "2021-03-02T09:15:03.413324Z"
    }
   },
   "outputs": [],
   "source": [
    "import time\n",
    "import matplotlib\n",
    "import numpy as np\n",
    "import paddle\n",
    "from numpy import pi as PI\n",
    "from matplotlib import pyplot as plt\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "from paddle import matmul, transpose\n",
    "from paddle_quantum.circuit import UAnsatz\n",
    "from paddle_quantum.utils import pauli_str_to_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-02T09:15:03.845958Z",
     "start_time": "2021-03-02T09:15:03.840512Z"
    }
   },
   "outputs": [],
   "source": [
    "# Main functions used in this tutorial\n",
    "__all__ = [\n",
    "    \"circle_data_point_generator\",\n",
    "    \"data_point_plot\",\n",
    "    \"heatmap_plot\",\n",
    "    \"Ry\",\n",
    "    \"Rz\",\n",
    "    \"Observable\",\n",
    "    \"U_theta\",\n",
    "    \"Net\",\n",
    "    \"QC\",\n",
    "    \"main\",\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generating the Dataset\n",
    "\n",
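    "A question we cannot avoid in supervised learning is: what does the dataset look like? In this tutorial we follow the method described in [1] to generate a simple binary dataset $\\{(x^{k}, y^{k})\\}$ with a circular decision boundary, where each data point $x^{k}\\in \\mathbb{R}^{2}$ and each label $y^{k} \\in \\{0,1\\}$.\n",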
    "\n",
    "![Dataset](figures/qclassifier-fig-data-cn.png \"Figure 2: The generated dataset and its decision boundary\")\n",
    "<div style=\"text-align:center\">Figure 2: The generated dataset and its decision boundary </div>\n",
    "\n",
    "The code below generates and visualizes the dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-02T09:15:04.631031Z",
     "start_time": "2021-03-02T09:15:04.617301Z"
    }
   },
   "outputs": [],
   "source": [
    "# Generator for a binary dataset with a circular decision boundary\n",
    "def circle_data_point_generator(Ntrain, Ntest, boundary_gap, seed_data):\n",
    "    \"\"\"\n",
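    "    :param Ntrain: size of the training set\n",
    "    :param Ntest: size of the test set\n",
    "    :param boundary_gap: width of the gap between the two classes, a value in (0, 0.5]\n",
    "    :param seed_data: random seed\n",
    "    :return: train_x, train_y (Ntrain points) and test_x, test_y (Ntest points)\n",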
    "    \"\"\"\n",
    "    train_x, train_y = [], []\n",
    "    num_samples, seed_para = 0, 0\n",
    "    while num_samples < Ntrain + Ntest:\n",
    "        np.random.seed((seed_data + 10) * 1000 + seed_para + num_samples)\n",
    "        data_point = np.random.rand(2) * 2 - 1\n",
    "\n",
    "        # Points with norm below 0.7 - boundary_gap / 2 are labeled 0\n",
    "        if np.linalg.norm(data_point) < 0.7 - boundary_gap / 2:\n",
    "            train_x.append(data_point)\n",
    "            train_y.append(0.)\n",
    "            num_samples += 1\n",
    "\n",
    "        # Points with norm above 0.7 + boundary_gap / 2 are labeled 1\n",
    "        elif np.linalg.norm(data_point) > 0.7 + boundary_gap / 2:\n",
    "            train_x.append(data_point)\n",
    "            train_y.append(1.)\n",
    "            num_samples += 1\n",
    "        else:\n",
    "            # Points inside the gap are discarded; shift the seed and draw again\n",
    "            seed_para += 1\n",
    "\n",
    "    train_x = np.array(train_x).astype(\"float64\")\n",
    "    train_y = np.array([train_y]).astype(\"float64\").T\n",
    "\n",
    "    print(\"Training set shapes: x {} and y {}\".format(np.shape(train_x[0:Ntrain]), np.shape(train_y[0:Ntrain])))\n",
    "    print(\"Test set shapes: x {} and y {}\".format(np.shape(train_x[Ntrain:]), np.shape(train_y[Ntrain:])), \"\\n\")\n",
    "\n",
    "    return train_x[0:Ntrain], train_y[0:Ntrain], train_x[Ntrain:], train_y[Ntrain:]\n",
    "\n",
    "\n",
    "# Visualize the generated dataset\n",
    "def data_point_plot(data, label):\n",
    "    \"\"\"\n",
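    "    :param data: array of shape [M, 2], i.e. M two-dimensional data points\n",
    "    :param label: labels of the data points, each 0 or 1\n",
    "    :return: None; plots label-0 points in red and label-1 points in blue\n",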
    "    \"\"\"\n",
    "    dim_samples, dim_useless = np.shape(data)\n",
    "    plt.figure(1)\n",
    "    for i in range(dim_samples):\n",
    "        if label[i] == 0:\n",
    "            plt.plot(data[i][0], data[i][1], color=\"r\", marker=\"o\")\n",
    "        elif label[i] == 1:\n",
    "            plt.plot(data[i][0], data[i][1], color=\"b\", marker=\"o\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-02T09:15:06.422981Z",
     "start_time": "2021-03-02T09:15:05.043595Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
       "Training set shapes: x (200, 2) and y (200, 1)\n",
       "Test set shapes: x (100, 2) and y (100, 1) \n",
       "\n",
       "Visualization of the 200 training data points:\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAD4CAYAAADhNOGaAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAoeklEQVR4nO2df6weV5nfP49fY2kNlCbXJjg/fJ1IESJZAU2uIgIrlnQJm3hLA9UiQZ3Uu0W1fAXSsupu5chSFRVZZam620BJu27Wkpt7C6pU2ETUaQjZRduCFriJEsdp6vwiP4xT4jiwLAIV4pz+MfNyx6/n98yZOfPO9yMdvfN7nnfmzHnOOc9znmPOOYQQQoyXDX0LIIQQol+kCIQQYuRIEQghxMiRIhBCiJEjRSCEECNnY98C1GHLli1ux44dfYshhBCD4sEHH3zZObd1dvsgFcGOHTtYW1vrWwwhhBgUZvZc2nZ1DQkhxMiRIhBCiJEjRSCEECNHikAIIUaOFIEQQoycVhSBmR0ys5fM7FjGfjOzz5nZU2Z21MyuSuy7wcyOx/v2tSGPEHVYXYUdO2DDhuh3dbVvicIk9OcUinyhyFEK51zjBLwXuAo4lrF/J3AvYMC7gG/H2yfA08BlwCbgEeCKovtdffXVbsisrDi3uOicWfS7stK3RGJlxbnNm52D9bR5s97NLL6eU1vfRCjvMRQ5ZgHWXFoZnbaxTgJ25CiCPwU+llg/DmwDrgXuS2y/Fbi16F5DVgRpGQScW1joP5P4JHTlt7h47juBaLtYx8dzarPQDOU9hiLHLFmKoCsbwUXAC4n1E/G2rO3nYGZ7zGzNzNZOnTrlTVDf7N8PP/3pudtPn4Y9e9ptPobSNF1djf7bc89Fn8Nzz7X/X5vy/PPVto+V51KHIzV7TmnfxE9/Gm2vSijvMRQ5ytKVIrCUbS5n+7kbnTvonFtyzi1t3XrOCOnBkPUhQf3Mn0ZIhW+bH7ovtm+vtt03oSjxJKurYGlfLPWe0/Q/tqlcQnmPochRlq4UwQngksT6xcDJnO1zS9aHNKWtGkNIhe8QakcHDsDmzWdv27w52t41ISnxJPv3R/LMYlb9OSX/YxZ1Cs2677GJ4k07N6T8VIq0/qI6iXwbwW9xtrH4O/H2jcAzwKWsG4uvLLpXGzaCvvqs0/oN0/oQm8pnln59s5b/UAlC7S+dJRQ7hq/n5StPQXVZsv5jG4bV5WXnJpPoOpNJtJ5HExtF3rmh5Kck+DQWA18EXgR+QVTL/ziwF9gb7zfgC0QeQo8CS4lzdwJPxPv2l7lfU0XQp0W/TOZvQ76QCt9QPShCxYcSDy1P5SmVrr2GmvyvkL6zMnhVBF2npoog6+VNJv6198JC+r03bFi/ZxuZq6/CN6sWFGLtKFR8FC6h5SlfBWid65ZRvFn5t8m5fSBFkCCvNuK70FxZcW7TprPvtWnT2fdqq0bYdQZUzb8dfDzH0PKUr7xS538WKY88WZuc2wdSBAmK+id9N++KPqahNTenDFXuEGlbiYf4bnxUVOr8z6LCOu+aTc7tAymCBFmDutJSH4RWiyhLSAZqcTZt56lQujtm5Vhervc/8/5PUb5ucm7XSBHMsLKy7lmQlSaTxrdpJF8IH1oVQqv9iLMJvVunLTmWl8NpTYX2TUgRpFDGViDKE0oBIfwSSuHWlRy+3Ev7IEsRjDoMddGAlcXFbuToCt+jVXftgoMHo+dmFv0ePBhtF/NDKAMEu5KjSb4ezDeRph1CT221CPJsBfNWkw2tZiKGy9haBHkMrQsXtQjOJamtASaT6DdYrd2AkEJOiGETSviEvuUINRRIHSxSEsNiaWnJra2t9S3GoNiwIcqss5jBa691L48YNqurUSXi+eejLtYDB/qpOPUpR1bAvMVFePbZbmSoipk96JxbOme7FME4GGKmFSJkhli5ylIEo+4ayiPEMMBN6LsZLcS80STUdGjlixRBCl30
/VXJCG1kmsF4LwgxEJqEvA7OtpBmQQ49+Z6q0rc3QhUPHnn7CNE+dbx90s6pc50+vZ3QgLLy+B4WXiUjhOAiJ8Q8UaZy1VboijT6DDuRpQhkLE7Bt2G1ipFpiAYpIUKm6Puedt0k3a3N0r/DOmVCn44bXo3FZnaDmR03s6fMbF/K/j80s4fjdMzMzpjZ+fG+Z83s0XhfEK5Avg2rVYxM82SQEsIXVfJ60YjktDE3WfXlOqOYg3TcSGsmVEnAhGh2sctYn27yipzjPwj8RWL9WWBLlXv67hpyzu+IwS5sBLItiLFQNa8XdbeWiUHWtIu2v6lyPdkIgGuB+xLrtwK35hz/X4B/llgPUhH4pkpGGJpBSoguqZrX684hMKsghlix8qkIfhu4M7F+C/DvM47dDLwCnJ/Y9j3gIeBBYE+Ze86DIvBNaHHQhfBFnbyeV7kqCm8N6yHshxBfKEmWImjDRmBpPU4Zx34Q+KZz7pXEtvc4564CbgQ+YWbvTb2J2R4zWzOztVOnTjWTeAQ0sS0IMSTq5PVduyLD7GuvRb/J8TRZY27uuGO9f//MmejYIMYAtEAbiuAEcEli/WLgZMaxHwW+mNzgnDsZ/74EfAW4Ju1E59xB59ySc25p69atjYWeB/IMZEEapITwgI+8nqUo6gZv7HoAaWXSmglVErAReAa4lHVj8ZUpx72JqFvo9YltrwfemFj+FnBD0T3VNVTPF3pITVghqtBGXi9zjbrdUKEMIMXngDJgJ/AEkffQ/njbXmBv4pjfAb40c95lseJ4BHhsem5RkiJwbmEhPUPKGCxERFWHjDIFcB0njJAGkHpVBF2nsSuClZX0zCJjsBARbbuU1r2uc+VbEV1811mKQEHncgh1QFZef6SMwUJU78vPG2SWLAf274fdu6sFbyxjzJ6OZq56jdZI0w6hp64GlIU6ICtvwEsI8gnRN1X78rNaBAsLzcuBtLJk06bo2tNuq6yu3q5sBGoRZBDy1I5ZtYOFBYWVFgKqu5RmeR5B83Jg1h11YSEq4k+fjn6fey5azqKLcPFSBBkUxSPpk6xMe/vt/cgjRGhUdSnNGjvwyivpx1ctB5LuqG94A/ziF+XOW1zspnInRZBBXo2ib9uBJpkRIp8630ja2AEfAzPLKpFOx/2k9ReFnvq0EbQZl1wIETY+bIV59og0d9c2xwMhG0E1smoUR46EazsQQrSLj9Z3XtduskUCsGUL3Hyz/2ktpQhSmHb93HJLtH7XXetNxZBtB0KI9smLSwTVu4qzlAusX2fLFvjd3003IvuoeEoRzFA0sbSCuQkhptSdiH5WucDZ1zl9Ot+g3HbFc5SKIE+DF7mNthXgqm+DsxCiOW25maddJ4/WK55phoPQUxNjcZHxp8xAlKbGm5AHqwkhytPWvB9VZkVrUlagyesjiiaO7mJi6T4nrxZCtEdb33LWdWZZWIiMynWN1V4nrx8SRcbeLuL4y+AsxHzQVnmRdp1Nm6KCf2pQXlmBl1/2M15odIqgyNjbxWAtGZyFmA/aKi/SrnPoUFTwZ3krtUpaf1HoyaeNoE2ybAmyEQgxDkKbHArNR7BOFy+nqLAPLYMIIdolxApfliJoxVhsZjcAtwMT4E7n3Gdm9r8PuBv4Xrzpy865f1Xm3DSaGIu7QgZhIcZN1TJgdTVyI33++aib+MCB9ruDsozFG1u48AT4AnA90UT23zWze5xz/3vm0P/pnPsHNc8dHDIICzFuqpQB04Fp07EE04FpMJzoo9cATznnnnHO/Rz4EnBTB+cGjQzCQoybKmVA3sC0LgaftqEILgJeSKyfiLfNcq2ZPWJm95rZlRXPxcz2mNmama2dOnWqBbH90oUbqhAiXKqUAVmth2nLYAhB5yxl26zh4SFg0Tn3DuDzwJ9XODfa6NxB59ySc25p69atdWXtDM0ZIMS42bUrmt94MonWJ5NoPa0MyGo9TCbdRDtuQxGcAC5JrF8MnEwe4Jz7sXPuJ/HyEeB1
ZralzLlDpihqoRBiflldhcOH4cyZaP3MmWg9rTaf1XqYnjtLiEHnvgtcbmaXmtkm4KPAPckDzOwtZmbx8jXxfU+XOVcIIYZIlYB0WT0Ii4vp127b1tjYa8g596qZfRK4j8gF9JBz7jEz2xvv/4/AbwPLZvYq8DPgo7FPa+q5TWUSQoi+yYodlLV91670XoOkNxH4sTWOLuicEEJ0wcaN6V07kwm8+mr567Q5vsDbOAIhhBDnktW/n7U9i6yWQpuMLuicEEJ0QVb/ftb2PpEiEEIIDwxpLJEUgZhv2hqWqblFRUWGNJZIiiAFffNzQt2ZxX1dR4yOoYwlktfQDLPBnyBqzoWqyUUOvucRVChZMTA0VWVJqgwCEYFTNfxjVjNQoWTFnCNFMIO++QCp21dXNvxjUdePQsmKOUeKYAZ984HRpH++rNtGUTNwSO4fQtRAimAGffOB0aSvrqzbRlEzsG33D3kjiAI6zyJp81eGnprOWVyE5hMOCLOzJ32dJrP27rG4mH6PxcX27jElxIlsRVD4zCJkzFmsFkEKQ3H5GgVt9NUVVa+aNAOrVt3kjSAK6CWLpGmH0JPvFoEIiKbVo7LnZzUD85qHdWTrooUjBo3PLEJGi6D3Qr1OkiIYGU366pp0+xQV9HWu3WU3lBgkPrNIliJQ15AIh6xuliZ9dU38gYva6HWuLW8EUUAfWUSKQISBrzAOTWwMRQV9nWsPKQCN6IVeskhaM6FqAm4AjgNPAftS9u8CjsbpW8A7EvueBR4FHiaj2TKb1DU0h/hqDzexMRTJJA8gMTCyytjGLQIzmwBfAG4ErgA+ZmZXzBz2PeDXnXNvBz4NHJzZf51z7p0uJQaGGAm+hnQ3qV4VtdGzrg3NncA11kB0SZp2qJKAa4H7Euu3ArfmHH8e8P3E+rPAlir3VItgDgnBiJpmlK5qqG6jlaCWhvAEHo3FFwEvJNZPxNuy+Dhwb1IXAV8zswfNbE/WSWa2x8zWzGzt1KlTjQQWAdK3ETXLRgHVDNVtOIG36UiuloUoQ5p2qJKAjwB3JtZvAT6fcex1wOPAQmLbhfHvm4FHgPcW3VMtgjml6yHdyftNJu20SNpwAs+6BlSTpUzLQsPoRwUeWwQngEsS6xcDJ2cPMrO3A3cCNznnTicU0cn49yXgK8A1LcgkhkiRm2jV2m3a8dNtZnDLLestgKwZxavaKNoYCZ11rFm1Gn1Ry0IT7ogpadqhSgI2As8AlwKbiGr1V84cs53Io+jdM9tfD7wxsfwt4Iaie6pFMEDq1jyn501r1WX7zdNqw5s2Ofe612XXtotaBGX+Q1s2gqxWQZUWSlHrJAS7jOgUfI4sBnYCTwBPA/vjbXuBvfHyncAPiVxEH54KA1wWK45HgMem5xYlKYKBUbdwTDuvbKGVVchVSUkZq/yHNrpbsmSq0sVUVNAr3MXo8KoIuk5SBAOjbs2zTGGeVWjl9bOXSRs2nF2Ad117LnO/IoWzvJx+jeXlfv6T6J0sRaCRxcI/dccIpM0TPIuvmYRee+3s9a6nrsvyotq5M93Gkda/f+RI+rWn2/v21BLBIEUg/FO3sJ5M8vfnFVoHDkSFZROS7ppdT12XNlht9244fHhdQTp39jmzLqZdT7gjhktaMyH0pK6hgVHXRlBkG2hy/rQ7Jav7ZLbbKYRBXlW7ytT1I2ZAXUOiN+rWPBcXs7eXGdyVd/7URfWOO2BhIf24ZG0/hNpzmW6opMzq+hFlSdMOoSe1CEZCl5PS9F3bL0NRi6DKhDtilCCvITFImhZkZc8fQoGZprCm3lFdyjyEZyVSkSIQ88VYC6O+//dQWk8ilSxFYNG+YbG0tOTW1tb6FkP0xTQ0QjJ8wubN8njxzepq5LmUFo5jarcRQWNmD7qUcP8yFovh0WZ0TlGOqfJtKyaTCAopAtE/VYPJdT24a54p++zTlG8S5xTmesBIEYh+qRMBs+vBXSHTZL6BKs++jJJV9NLhkmY4CD3JWDxH
1Bn0NDSDpS8Db9FzKLpvlWdfJYifBqwFC/IaEkFSNwJm394zZfGptPIK8jL3rfLsy0SCLfvuRG9kKQJ5DYl+2bEjPbjcvHih+Px/GzacG28IopHP27cX37eqbKurka3g+eej6//kJ3D69LnHzcu7m0PkNSTCZN7DIPg0bOfZSsrct+qzn51B7vbb5/vdjQgpAtEvfcbw6WJid5+G7byCvMx9qzz7tGcVQvwl0Q5p/UVVE3ADcJxoOsp9KfsN+Fy8/yhwVdlz01IbNoKhdDELT3Q1sbtvw3aWjG3ed2jGeZEJvozFwIRoisrLWJ+z+IqZY3YC98YK4V3At8uem5aaKgLla1HoMdN2QdpHraMtRTaZ5D8rMRiyFEFjY7GZXQvc5pz7zXj91ril8a8Tx/wp8A3n3Bfj9ePA+4AdReem0dRYPO/2SVGCPEPra68pk0B6KI8k02clBoNPY/FFwAuJ9RPxtjLHlDkXADPbY2ZrZrZ26tSpRgJrYKoo7ENvkkm6sD10QdFo4jEO4JtT2lAEafMBzla1so4pc2600bmDzrkl59zS1q1bK4p4NhqYKgo9ZupmkjojpUMlT+nJO2iuaEMRnAAuSaxfDJwseUyZc1tn3j0WgyWkmnKRx0vdTDJPAfGylN5kEj0rCOd9imakGQ6qJGAj8AxwKesG3ytnjvktzjYWf6fsuWlJXkMDpKrxNYQXtLy8biidTKL1IuqOlA6RvHcmj4tBgs8QE0ReQU8QeQDtj7ftBfbGywZ8Id7/KLCUd25RUoiJQMkrvKvEtQmhkKkrw7xNGJ/1TrP+52Si2lXAeFUEXScpggApKjir1JRDKEzryhCKEvPdmsp6n2ohBI0UgfBLUcFZpWANoXuliQx9dmt1pYjKRiMdaktoTslSBAoxIdqhyN2yivE1BLeuIhnyDN+zMXm6DLmQZay++eZ2Dbpp7zMN+WQPgzTtEHpSiyBAytT4y9aUQ+leqWIonbYg+u4bL+qy8RXeQqOPBwHqGhJeabvwDsFrqKqhNIS+8TJdNgsL7d83BOUtCpEiEP4JofDugjKG0r5qwmUnkPHxbsby/gdMliKQjUC0R599411SxlbRV994cqBcHkUD3KoM/psee8st0fpdd833+59DpAiEqEoZQ2mf8UqmCnllJfuYPEVVJUzGPIXUCIwuB+JLEQhRxOwXCWfXum0mZFaX8UqKvJcWFtLPy1NUVcJkzFNIjYDoXL+m9ReFnmQjEF5I6+PuagKbuvKWka2qETfPBjL7P0MY8zGH+BpTiSavFyKHtNj7mzfDr/xKuBO0l50zYXbS+QMH8vvvs65rFpVHU0J/PgOmaLqMumjyejEe6nSuZnVxpBVyEMZAqbJzJlQ14qfZQGaVAKw/L4XybZ2ux1RKEYj5om7natWCPYTJK3yVFmkhurN6Dl55RRPYe6DzUPlp/UWhJ9kIRCZ1O1ezzltYaDZQqo79IMQR2CEEAhwByVe/sBClNk1PaECZGAV1jZdFISXqGIPrFNTLy+f+hxDmbdDIYe908YilCMQ4aFJzbbtQrSpLnhdOCDVvjRz2SheNrixFIK8hET5VvF6yvH/66Leu6vqR5a2Td46YG3x5Cp19LQ9eQ2Z2vpndb2ZPxr/npRxziZn9pZk9bmaPmdnvJfbdZmbfN7OH47SziTxiDqlq/C2ai7hLqhpz8wzWIRinhVf6jL7e1GtoH/CAc+5y4IF4fZZXgX/unHsb0XzFnzCzKxL7/8Q59844HWkoj5g36oxcDSXmUVXXj6wv3kzumCOgc0+hBE0VwU3A4Xj5MPCh2QOccy865x6Kl/8WeBy4qOF9xdAp6+tf1lc+FJL/a/9+2L27fOsky39/795oOet5dRmURnij18ZsmuGgbAJ+NLP+w4LjdwDPA38nXr8NeBY4ChwCzss5dw+wBqxt3769PeuJ6J4q7hFDcltsw+2japgLefPMDV3Y4qnrNQR8HTiWkm6qogiANwAPAv8ose0CYELU
MjkAHCqSx8lraPhUKdyHVND5Ulp51x2SohSZ+PBeTiNLETTyGjKz48D7nHMvmtk24BvOubemHPc64KvAfc65P8641g7gq865Xy26r7yGBk5V94iqsXL6wpfbR951wb+rifBOlsPYwgL87GftOcH5ijV0D7A7Xt4N3J1yYwP+DHh8VgnEymPKh4laGmLeqeIeMRQlAP7cPvKu26eriWiNLJPX6dPdRPluqgg+A1xvZk8C18frmNmFZjb1AHoPcAvw91PcRD9rZo+a2VHgOuD3G8ojhkBZ94g2grJ3aUj15faRd90qz1IG5WCpqrdb95VI6y8KPclGMAeU6fhs2v/dh33Bl8Uv77pF9xySnWWkZL2ihYV2TUBoZLEYHFl945C9PUnZeP3zjp7DIEjrBYV2B8prPoIU1FoOnLwBVmVe1tDGIPhCz2EQpI2DTI4tAJhM1m0EbZZXo1UEmnN7ABw4cO58wBC9sDLWMhlSI/QcBs2uXeumoDNnom1tl1ejVQSac3sA7NqV3QVUpjbb55h9X9Rpxs7jcxgZvsur0SoCtZYHwrRNPEuZ2mxIAejaoG4zdt6ewwjxXV6N1lgs+9lACCmsdN8o046Wtl69jMUzqLU8EFSbXUfN2NHiu7warSJQ+TIgQgkr3Tcy+o4W3+XVaBUBZJcvcisVQaJm7NxRpazxWR/a2N6l5oPZLumpPQ7GWxEVgTDNgEOJvSRyCamsGa2xOAvZ44QQXdBHWSNjcUlkjxNCdEFIZY0UwQyyxwkhuiCkskaKYAbZ44QQXRBSWSNFMIPcSoUQXRBSWSOvoRSmUf+EEKJtQpx0r1GLwMzON7P7zezJ+Pe8jOOejWcie9jM1qqeL4QQoVJlLECoUY+bdg3tAx5wzl0OPBCvZ3Gdc+6dM65LVc4XQoigqFqwhxr1uKkiuAk4HC8fBj7U8flCCNEbVQv2IpfRvqIaNFUEFzjnXgSIf9+ccZwDvmZmD5rZnhrnY2Z7zGzNzNZOnTrVUGwhhGhO1bEAeS6jfXYbFSoCM/u6mR1LSTdVuM97nHNXATcCnzCz91YV1Dl30Dm35Jxb2rp1a9XThRCidaqOBchzGe2z26hQETjn3u+c+9WUdDfwAzPbBhD/vpRxjZPx70vAV4Br4l2lzhdCiBCpOhYgz2W0z5HGTbuG7gF2x8u7gbtnDzCz15vZG6fLwAeAY2XPF0KIUCkzFmC23x/So4j2OdK4qSL4DHC9mT0JXB+vY2YXmtmR+JgLgP9lZo8A3wH+u3Puf+SdL4QQoZFlyM0LD12l37/XkcbOucGlq6++2lVlZcW5xUXnzKLflZXKlyh1rTbvI4Tol+n3DNE3HRXnUdq8ufj7np47mxYX8+/nq/wA1lxKmTqKMNRp096awd69cMcd1e6dN4UuaHpdIeaFtG99lqKQ0Rs2REX/LGZRC6JrssJQj0IRZMX9NoO77qpWSOfFEAfNZSDEvJD1rScpKtBDm99k1PMRZFndnavumpVn2Q8pvrgQohllvtsiQ25Wv//OnWFNhzsKRZD3sqoW0nmW/ZDiiwshmlH03ZYx5KZ5Fe3eDYcPhxVvaBSK4MCB6CWkUbWQzrPshxRfXAjRjLTveVqOVAkZPetVdORIgPGG0izIoac6XkPLy8VW/+Vl5yaTaN9kEq2nIa8hIcaBj+95thyaJrPm1y6CDK+h3gv1OqmOInAu/6UuL6e/nCxl4AspEiHmg6xvuapLaZtkKYJReA2VYeNGOHPm3O2TCbz6aqu3yiTPNVXup0IMh1DdzEftNVSGNCWQt90HocYqF0IUkxx5vHt39rcc0hSVU9QiiAmhRRDa4BMhRDnKDD6D/r9ltQgK2LOn2nYfyP1UiGGS1ppPI9RvWYog5o47YHk5agFA9Lu8vB6CoouZg+R+KkQz+prhq8x4pLxvuS+5f0maBTn0VNdrqC4rK5GraZbraVcB7YQQ2RR9pz7J8gRKegRlydGl3MhrqD558UIOHJCnjxAh0Gdc
n9VVuPnm9H0hxSOSjaABeTGEsjx9du/uP36IEGMi6zt97jnYssVvt8uuXbCwkL6vyC4QQoyyRorAzM43s/vN7Mn497yUY95qZg8n0o/N7FPxvtvM7PuJfTubyOOLPCNu1ss6c6b/+CFCjIm8Avf0af9xfW6/vZ6NLwQnkaYtgn3AA865y4EH4vWzcM4dd8690zn3TuBq4KdE8xZP+ZPpfufckdnzQyDPiJv3sjQGQIjuSPtO0/D1XdYdHxCEk0ia4aBsAo4D2+LlbcDxguM/AHwzsX4b8AdV79u1sdi5bCNumqEnL35I8joLC1GSYViIdlhZyTfadhnXpwpdOYngI9YQ8KOZ9R8WHH8I+GRi/TbgWeBovO+8MvftQxHksbKyHqwuL35IkdLw7eEgjyQxBoo8eMp48swrtRUB8HXgWEq6qYoiADYBLwMXJLZdAEyIuqgOAIdyzt8DrAFr27dv7+CRVaOMC1iZDOor8FSfrnVCdElRhWvM34CvFkHprqFYcXwtZ/8O4FiZ+4bWIphSVOPOCj/bRZO1z4iHQnRNWhds15WvEMlSBE2NxfcAu+Pl3cDdOcd+DPhicoOZbUusfpiopTFYZiegmDUSlfEC8OUpEIKLmhB5tDm6NvktvvxylLImp9I30Nxr6DPA9Wb2JHB9vI6ZXWhmv/QAMrPN8f4vz5z/WTN71MyOAtcBv99QnqAp8mrw6SkQgouaEFlMg7b5nL5R30AOac2E0FOoXUNl6MtrSDYCETJtdF0Wdc2G/A0M2muorzRkRdAn8hoSodJ0+sayhXzZ6Wi7RLGGatJ1rCEhhF+axtspc36oMwBu2RKNfJ5FsYaEEKOi6ejaMs4QIc4AuLqargRgQLGGhBCiDZpO31hkCF5dTW8xQL9eQ3lKaEixhkRL9D4xhRAtUTcvF7lf55HXoph2CWXRp9dQnhIaTKyhvtK8GYtD9maoigzS80nZ99pnXs6SMW9Ef9/fWZZsCwt+7oe8hsJlXkb9zpNCE+tUea8h5uW8Ef19582uvxkpgoBp6joXCiEWAqI5Vd5rKHk52TooExCy7rXbaPV22YrOUgSyEfTItC/VZXjwDm3Eo8JYzCdV3msIo3dnRymfOXPuMXVH8fsYAd3ENtIWUgQ9kcxQabQVbqJLI3QIhYBonyrvte9JVlZXo2liZ91EASaTeh5JSUJ0QW2FtGZC6GkeuobyDFhtNQ+77n+UjWA+qfpeqxiW2+5iqTJJVB1C6fqqC7IRhEUXGaqPPnt5Dc0nXRTaTSsNRfN9tJHvh24HkyIIjC4y1NBrL6J9QlHUPvJ/nndQWy3Tobd6sxSBbAQ90UVfqvrsw6brQYRdhHouiw/Hgqx8PZm0F0+o6QjoYEnTDqGneWgROOe/djb02ss808e7aVILbzuv+mgRKL8Xg7qG5pusD3WoPs+hdGH4oo++5rpdhT4KWF+F9rznm6Z4UQTAR4DHgNeApZzjbiCa3/gpYF9i+/nA/cCT8e95Ze4rRXA2XdWEsu6zvBy+ITE0+rDf1FU+vpSWCu3u8aUI3ga8FfhGliIAJsDTwGXAJuAR4Ip432enigHYB/xRmftKEZxNV7XLrPvMFmq+vD+G4plRhr48uuooWDkdzA9ZiqCRsdg597hz7njBYdcATznnnnHO/Rz4EnBTvO8m4HC8fBj4UBN5xkpXI3qzrhfp8XWaDrAZwwjlPgZe1TV0yulg/unCa+gi4IXE+ol4G8AFzrkXAeLfN2ddxMz2mNmama2dOnXKm7BDpKsPtcr1fHh/zFPB05f3SZ1wBn2PFhb+KVQEZvZ1MzuWkm4qOnd6iZRtLmVbLs65g865Jefc0tatW6uePtd09aGm3cfS3i7NCu2xFDwhxJgpw9y6TIpfsrHoAOfc+xve4wRwSWL9YuBkvPwDM9vmnHvRzLYBLzW81yiZfpD790c18e3bo0Kz7Q817T47d8Lhw+fOA9uk0O7q/4jy7Nql5z/PtDJ5vZl9A/gD59w5
M8qb2UbgCeA3gO8D3wX+sXPuMTP7N8Bp59xnzGwfcL5z7l8U3U+T14fF6qoKbSGGQNbk9Y0UgZl9GPg8sBX4EfCwc+43zexC4E7n3M74uJ3AvyPyIDrknDsQb18A/iuwHXge+Ihz7pWi+0oRCCFEdbwogr6QIhBCiOpkKQLFGhJCiJEjRSCEECNHikAIIUaOFIEQQoycQRqLzewUkDHbby5bgJdbFqcNQpULwpVNclUnVNkkVzWayLXonDtnRO4gFUFdzGwtzWLeN6HKBeHKJrmqE6pskqsaPuRS15AQQowcKQIhhBg5Y1MEB/sWIINQ5YJwZZNc1QlVNslVjdblGpWNQAghxLmMrUUghBBiBikCIYQYOXOnCMzsI2b2mJm9ZmaZLlZmdoOZHTezp+IQ2NPt55vZ/Wb2ZPx7XktyFV7XzN5qZg8n0o/N7FPxvtvM7PuJfTu7kis+7lkzezS+91rV833JZmaXmNlfmtnj8Xv/vcS+Vp9ZVp5J7Dcz+1y8/6iZXVX2XM9y7YrlOWpm3zKzdyT2pb7XjuR6n5n9TeL9/Muy53Yg2x8m5DpmZmfM7Px4n5dnZmaHzOwlMzuWsd9f/kqbyHjICXgb8FbgG8BSxjET4GngMmAT8AhwRbzvs8C+eHkf8EctyVXpurGM/5doAAjAbURzPrT9vErJBTwLbGn6v9qWDdgGXBUvv5Fo7ovpu2ztmeXlmcQxO4F7iWblexfw7bLnepbr3cB58fKNU7ny3mtHcr0P+Gqdc33LNnP8B4G/6OCZvRe4CjiWsd9b/pq7FoFz7nHn3PGCw64BnnLOPeOc+znwJWA69eZNwOF4+TDwoZZEq3rd3wCeds7VGUFdhab/19fzKnVt59yLzrmH4uW/BR5nfU7sNsnLM0l5/7OL+Gvg71o0816Zc73J5Zz7lnPuh/HqXxPNEuibJv/Z5/Oqc/2PAV9s8f6pOOf+Csibj8Vb/po7RVCSi4AXEusnWC88LnDOvQhRIQO8uaV7Vr3uRzk3830ybhIearELpqxcDviamT1oZntqnO9TNgDMbAfw94BvJza39czy8kzRMWXO9SlXko8T1SqnZL3XruS61sweMbN7zezKiuf6lg0z2wzcAPy3xGZfz6wIb/mrcM7iEDGzrwNvSdm13zl3d5lLpGxr7EebJ1fF62wC/iFwa2LzfwA+TSTnp4F/C/zTDuV6j3PupJm9GbjfzP5PXINpRIvP7A1EH+unnHM/jjfXfmZpt0jZNptnso7xkt8K7nnugWbXESmCX0ts9vJeS8r1EFHX509i+82fA5eXPNe3bFM+CHzTnT1zoq9nVoS3/DVIReCce3/DS5wALkmsXwycjJd/YGbbnHMvxs2ul9qQy8yqXPdG4CHn3A8S1/7lspn9J+CrXcrlnDsZ/75kZl8hao7+FQ2eV1uymdnriJTAqnPuy4lr135mKeTlmaJjNpU416dcmNnbgTuBG51zp6fbc96rd7kSChvn3BEzu8PMtpQ517dsCc5pmXt8ZkV4y19j7Rr6LnC5mV0a174/CtwT77sH2B0v7wbKtDDKUOW65/RJxgXhlA8DqZ4FPuQys9eb2Runy8AHEvf39bzKymbAnwGPO+f+eGZfm88sL88k5f0nsXfHu4C/ibu0ypzrTS4z2w58GbjFOfdEYnvee+1CrrfE7w8zu4aoPDpd5lzfssUyvQn4dRL5zvMzK8Jf/mrb8t13IvrgTwD/D/gBcF+8/ULgSOK4nUQeJk8TdSlNty8ADwBPxr/ntyRX6nVT5NpM9DG8aeb8u4BHgaPxS97WlVxE3giPxOmxLp5XBdl+jagZfBR4OE47fTyztDwD7AX2xssGfCHe/ygJr7Ws/NbScyqS607gh4nns1b0XjuS65PxfR8hMmK/u4vnVUa2eP13gC/NnOftmRFV/l4EfkFUhn28q/ylEBNCCDFyxto1JIQQIkaKQAghRo4UgRBCjBwpAiGEGDlSBEIIMXKkCIQQYuRIEQgh
xMj5/w62GGID6HZlAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
       "Visualization of the 100 test data points:\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAD4CAYAAADhNOGaAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAgsklEQVR4nO3df+xdd33f8efLDpEwRJA4TnB+fZ1WGVqoSkq+Mj9atSBIl7iiBmloSb9QT2PysimoTGu3ZJFQtMoSo6MTVJTKUISJvyJiApqoGEJgZSlDQL6J8sNpauJYjmPsJd84CJoRLU383h/nXHx8fX+c+73n93k9pKt77/lxz+eee+55f87n11FEYGZm/bWu7gSYmVm9HAjMzHrOgcDMrOccCMzMes6BwMys586qOwFrcf7558eWLVvqToaZWavcf//9z0bEpuHprQwEW7ZsYWVlpe5kmJm1iqQnR0130ZCZWc85EJiZ9ZwDgZlZzzkQmJn1nAOBmVnPFRIIJH1O0jOS9o+ZL0mflHRQ0sOS3pSZd62kA+m8m4tIj5nNZ3kZtmyBdeuS5+XlulNkZSrqiuDzwLUT5l8HXJE+dgKfBpC0HvhUOv9K4AZJVxaUJjNbg+Vl2LkTnnwSIpLnnTsdDLqskEAQEfcCz01YZDvwhUh8H3itpM3AVuBgRByKiBeBO9JlW8s5KWu7W2+Fn//89Gk//3ky3bqpqjqCi4GnMu+PptPGTT+DpJ2SViStrK6ulpbQeTgnZV1w5Mhs0639qgoEGjEtJkw/c2LE7ohYjIjFTZvO6CHdCG3JSfmqpV5N3/+XXTbb9Lar4/do2jFQ1RATR4FLM+8vAY4BZ4+Z3kptyEkNrloGAWtw1QKwtFRfuvqi6ft/eRmef/7M6Rs2wK5d1aenbHX8Ho08BiKikAewBdg/Zt7vAF8nuQJ4C/DDdPpZwCHgcpKg8BDwhmnbuvrqq6OJFhYikkKh0x/r10fs3Vt36hLj0riwUOx29u5NPlNKnpvy/etW1f5fi717IzZsODNtGzd29/fL+3sUcTwPPmPU9qo6BoCVGHWOHjVx1gfwReA48I8kuf8PAjcCN6bzRdI66AngEWAxs+424EfpvFvzbK+IQFDGiWrcHwmS6U34M0mj0ycVt41R+6Ep379uVez/tWpykJrXuP97nt+jiON50rmhymOg1EBQ9WPeQFDmiWrv3uQKoKl/qCr+7F0+ocyryfumyUFqHpP+73l+jyJ+s0lXAp25Iqj6MW8gKPvP2OQ/VBW59SZ//7o1+WqpyUFqHpO+V57fo4jjedxnVH0MOBBklH2iavofquzy+6Z//7o1tf6kiiBVx3ef9n+flqayrwiqPAYcCDLKPlE1OddXhb5//zYr80Rd13Ex7/+9rDqCOv4TDgQZXc35NEnfv7+dqa4rxaJO5EW1GqrzPzEuECiZ1y6Li4sx760ql5eTjl5HjiQdZbZtg337Tr3ftasZ7brNumLduuQ0PEyCkyfL3fbw/72v/29J90fE4vD03g5DvbQEhw8nB+CuXbBnT7FDQ1TRc7BpvRPNJqmzx3L2/374cD+DwCS9DQRZRQ8NUcWYQx7XyNpm166kh3LWcI9lZ25qMqq8qOmPonsWF92KaK1lobOUIbpljrXRpGO8KRWqXYbrCMbbsiXJUQ9bWEguI2e1lrLQ4fFHIMkt7d49+jK2zvJWszIU/T+0M7mOYII8l6yzWEtZ6KzFU30bIdK6rw2DNnaVAwFJjnv37iTnISXP43LieawlsMz6Jyg6eJnVzZmb+jgQpIpsVbCWwDLrn6Do4GVWt7IyN66AzmFUxUHTH00dhnoerigzK77Tlf9Xp2NMZbGvCBrCOXyz4tv713nXwDZdibjVkJl1Vl2t62ZtBVgVtxoys96pqwK6LfcvHygkEEi6VtIBSQcl3Txi/h9JejB97Jf0sqTz0nmHJT2SznM238wKU1frurY1hZ07EEhaT3Ib
yuuAK4EbJF2ZXSYi/iQiroqIq4BbgP8VEc9lFnlHOv+MSxYzs7Wqq+6tbU1hi7gi2AocjIhDEfEicAewfcLyN5Dc47jRyqzoaVMlklnb1THgXNv6+RQRCC4Gnsq8P5pOO4OkDcC1wJczkwP4pqT7Je0ctxFJOyWtSFpZXV0tINnjlTmgmweLM+u+trUCnLvVkKT3Af8sIv51+v4DwNaI+NCIZf8F8P6IeHdm2kURcUzSBcA9wIci4t5J2yy71VCZY554PBUzq0uZrYaOApdm3l8CHBuz7PUMFQtFxLH0+RngqyRFTbUYFNmMOlFDMRU9batEMrPuKyIQ3AdcIelySWeTnOzvGl5I0muA3wLuzEx7laRzBq+B3wb2F5CmmWWLbMYpoqKnbZVIZtZ9cweCiHgJuAm4G3gM+FJEPCrpRkk3ZhZ9L/DNiPi/mWkXAt+V9BDwQ+BrEfGNedO0FqPa/WYVVdHTtkokM6tPVQ1LziriQyJiH7BvaNpfDL3/PPD5oWmHgDcWkYZ5TSqaWVgo9h6nr3zlqaCzcSN84hPNrUQys3oM904eNCyB4s8X7lmcGlc0M6jELWLHD37YEydOTXvhhfk/18y6p8reyQ4EqSqKbNrW7dzM6lNlwxIHglQV7X7dYsjM8qqyYYkDQUbZPRDdYsjM8qqyYYkDQYXcYsjM8qqyd3IhrYYsn8EPeOutSXHQZZcV2xrJzLplaama84OvCCpWxwBYZlaNtg4o6SsCM7MCVNnuv2i+IjAzK0Cbm4c7EJiZFaDNzcMdCMzMCnDeebNNbxIHAjOznnMgMDMrwHPPzTa9SRwIzMwKkHfkgCY2MXUgMDMrQJ6RA5p6z3IHghk1MZqbWf3yDAnR1CamhQQCSddKOiDpoKSbR8x/u6SfSnowfXwk77pN0tRobmbNMG3kgKY2MZ07EEhaD3wKuA64ErhB0pUjFv3biLgqffyXGddthKZGczNrh6aOQFzEFcFW4GBEHIqIF4E7gO0VrFu5pkZzM2uHpo5AXEQguBh4KvP+aDpt2FslPSTp65LeMOO6SNopaUXSyurqagHJnl1To7mZtUOVQ0vPoohAoBHTYuj9A8BCRLwR+DPgr2ZYN5kYsTsiFiNicdOmTWtN61yaGs3NrD2aOAJxEYHgKHBp5v0lwLHsAhHxs4h4Pn29D3iFpPPzrNskTY3mZmbzKCIQ3AdcIelySWcD1wN3ZReQ9DpJSl9vTbd7Is+6TdPEaG4N5bbG1hJz348gIl6SdBNwN7Ae+FxEPCrpxnT+XwD/HPi3kl4CXgCuj4gARq47b5rMatfmwemtd5Scj9tlcXExVlZW6k6G2XhbtiQn/2ELC8mlpFkNJN0fEYvD092z2KwMbmtsLeJAYFYGtzW2FnEgMCuD2xpbizgQmJXBbY2tRRwIzMoyrq2xm5Vaw8zdfNTMZuBmpdZAviIwm6bIHLyHsLUGciAwm6Tom1BMa1bqYiOrgQOB2SRF5+AnNSv1nY+sJg4E1l5V5J6L7hg2qVmpi42sJg4E1k5V5Z6L7hg2qVmpeyNbTRwIrJ2qyj2X0TFsXLNS90a2mjgQWDtVlXsuq2PYqGIt90a2mjgQWDtVmXsu+iYU44q1wL2RrRYOBNYOwznobdvam3ueVKzlOx9ZRlWtiR0IrPlG5aD37IEdO07lnjduhFe+Ej7wgea3v3elsOVQZWviQgKBpGslHZB0UNLNI+YvSXo4fXxP0hsz8w5LekTSg5J8txk707gc9L59Sa759tvhhRfgxIl2tL8vuljLndA6qdLWxBEx14PkFpNPAL8EnA08BFw5tMzbgHPT19cBP8jMOwycP8s2r7766rAekSKSU/zpDymZv7Awev7CQp2pHm/v3ogNG05P64YNyfRZPmPwvYf3z6yfZY007bBfC2AlRpxTi7gi2AocjIhDEfEicAewfSjYfC8ifpK+/T5wSQHbLZQzVQ02LQfdtqKWeVsiZcsMIDk/ZLkTWidU2R6iiEBw
MfBU5v3RdNo4HwS+nnkfwDcl3S9p57iVJO2UtCJpZXV1da4ED3PP/oab1qyyje3v56kUHlVmMGzWIOicUONU2pp41GXCLA/gfcBnM+8/APzZmGXfATwGbMxMuyh9voCkWOk3p22z6KKhtpUs9NKgKERKnrNFH0UUtVSV1iKMKzNY68E7av9BxMaNLmKqWdGHEmOKhooIBG8F7s68vwW4ZcRyv0pSl/BPJnzWbcAfTttm0YGgjLI4q1jZJ99Z0lF2UBqXc1nr9iZ9nusbalfkoV1mIDgLOARczqnK4jcMLXMZcBB429D0VwHnZF5/D7h22jZ9RWCNVcXBNCrYDHIzazlTTLvC8B+hNkXnK8YFgrnrCCLiJeAm4O602OdLEfGopBsl3Zgu9hFgI/DnQ81ELwS+K+kh4IfA1yLiG/OmaVbu2W+FqaLielRl8+23J+eJtXRCm1aX0tRK9x6oqgmpkiDRLouLi7GyUmyXg+XlZOceOZL8L3btcqdOW4MtW0615slaWEhO0k00fPvMYU1Oe8etW3dmozBI4v/Jk7N/nqT7I2LxjO2sJXFd5J79Vog2Xl4OrjA2bjxzXtPT3nFVNYhzIDAbKKIJZVmjlZZtaQmefRb27m1f2jusqnyFi4bMYHTxyIYNPhFa7Yosth5XNORAYAbtLNs3m5HrCMwmadswFWYFciCwfhvUC4y7Mm7yMBVZHiLC5nBW3Qkwq820ZpNtaTEz/D2ydzxz/Ybl4CsC669Jg7e1qcVMpQPXWxf5isD6a1z5v9SuCmLXb9icfEVg/dXG4atH6cr3sNo4EFh/tbEX8Ch5vocrk20CBwLrrzJ6Addxwp32PXznJZvCHcrMitLU3snuLGcpdygzK1tTW++4MtmmcCAwK0pTT7iuTLYpHAjMitLUE24dleKunG6VQgKBpGslHZB0UNLNI+ZL0ifT+Q9LelPedc1ao6mtkKoeGtuV060zd2WxpPXAj4BrgKPAfcANEfF3mWW2AR8CtgFvBj4REW/Os+4oriy2xvKt7lw53WDjKouL6Fm8FTgYEYfSDd0BbAeyJ/PtwBfSmyd/X9JrJW0GtuRY16w9lpb6d+If1tS6EhuriKKhi4GnMu+PptPyLJNnXQAk7ZS0ImlldXV17kSbtUbbytubWldiYxURCDRi2nB507hl8qybTIzYHRGLEbG4adOmGZNo1lJtLG9val2JjVVEIDgKXJp5fwlwLOcyeda1vmtbjrhITe2bMElb79vcY0XUEdwHXCHpcuDHwPXA7w0tcxdwU1oH8GbgpxFxXNJqjnWtz/o+1n5by9tdV9Iqc18RRMRLwE3A3cBjwJci4lFJN0q6MV1sH3AIOAh8Bvh3k9adN03WIW3MERfJ5e1WAY81ZM22bt3o20hKcPJk9empWlPHL7JW8lhD1k59zxG7vN0q4EBgzTaqBcorXgHPP9+fyuOlpaQj1smTybODgBXMgcCabThHvHFj8nziRHuaU5rlVFcDOQcCa75sjvjVr4YXXzx9flMqj/vczNXmVmeXEQcCa5cmNqdcXobzz4f3v79dHb9m5UBXqjobyDkQWLs0rfJ4kI07ceLMeU25UilCG3s4t8yocfomTS+SAwHO6LRK04YvGJWNy2p6x6+8+t6fowLr1882vUi9DwTO6LTMuOaUUE80n3ai70oz1yYWyXXMyy/PNr1IvQ8Ezui00HBzSqgvmk860XdpoLWmFcl10MLCbNOL1PtA4IxOB9QZzUcVVUHSzLVLHb+aViTXQdu2JRe5WVXt4t4HAmd0OqDOaD6qqGrvXnj22e4EAXAP55ItL8OePaePpiLBjh3V7OLejzXkoVw6wLdGtJar6hD2WENjOKPTAS62sJaru4i694EAPJRL6zmaW8vVXUTtQGDd4GhuLVb3Re1cgUDSeZLukfR4+nzuiGUulfQ3kh6T9KikP8jMu03SjyU9mD62zZMes0q5J6IVpO6L2rkqiyV9DHguIj4q6Wbg3Ij4T0PL
bAY2R8QDks4B7gfeExF/J+k24PmI+G+zbNc3prHauZWBtVBZlcXbgT3p6z3Ae4YXiIjjEfFA+vofSG5JefGc2zWrl3siWofMGwgujIjjkJzwgQsmLSxpC/BrwA8yk2+S9LCkz40qWjJrpLqbeZgVaGogkPQtSftHPLbPsiFJrwa+DHw4In6WTv408MvAVcBx4OMT1t8paUXSyurq6iybNite3c08zAo0NRBExLsi4ldGPO4Enk7rAAZ1Ac+M+gxJryAJAssR8ZXMZz8dES9HxEngM8DWCenYHRGLEbG4adOm2b6lWdHqbuZhVqB5i4buAnakr3cAdw4vIEnAXwKPRcSfDs3bnHn7XmD/nOkxq0bdzTys86pslDZvq6GNwJeAy4AjwPsi4jlJFwGfjYhtkn4D+FvgEeBkuup/joh9km4nKRYK4DDwbwZ1DpO41ZCZdVlZjdLGtRrq/VhD0ywvJw1BjhxJin937XKmz8zKVdbYQx5raA1805qOcMcva5mqG6U5EEzgpuId4GhuLVR1ozQHggncVLwDHM2thapulOZAMIGbijdcniKfLkVzF3H1RtWN0hwIJnBT8QbLW+TTlWjuIq7eqXJAXQeCCdxUvMHyFvl0JZq7iMtK5Oaj1k7r1p1+g9cBKclCZXWhDfAs39dsjHHNR8+qIzFmc7vsstENrUcV+Swtte/EP2yW72s2IxcNWTt1pcgnr759X6uUA4G1U98qcPr2fa1SriMoQBeKoM2s+zzExJCimmS7VZ+ZtV0vA0GRJ2+36jOztutlICjy5N2ljqtm1k+9DARFnry70nHVzPqrl4GgyJO3W/WZWdvNFQgknSfpHkmPp8/njlnusKRHJD0oaWXW9YtW5MnbrfrMrO3mvSK4Gfh2RFwBfDt9P847IuKqoaZLs6xfmKJP3lUODmVmVrR571l8AHh7RBxPb0T/nYh4/YjlDgOLEfHsWtYf1rR+BGZmbVBWP4ILBzebT58vGLNcAN+UdL+knWtY38zMSjJ10DlJ3wJeN2LWLI0tfz0ijkm6ALhH0t9HxL0zrE8aQHYCXOYmOWZmhZkaCCLiXePmSXpa0uZM0c4zYz7jWPr8jKSvAluBe4Fc66fr7gZ2Q1I0NC3dZmaWz7xFQ3cBO9LXO4A7hxeQ9CpJ5wxeA78N7M+7vpmZlWveQPBR4BpJjwPXpO+RdJGkfekyFwLflfQQ8EPgaxHxjUnrm5lZdeYKBBFxIiLeGRFXpM/PpdOPRcS29PWhiHhj+nhDROyatn7b5BnAzvcdN7Om8h3K5jQYwG4wdtFgADs41Z8gzzJmZnXpzRATZeXI8wxg5xFKzazJenFFUGaOPM8Adh6h1MyarBdXBGXmyPMMYOcRSs26qwv1f70IBGXmyPMMYLdtWzKm0aRlzKx9unKHwl4EgjJz5NMGsFtehj17koNkQIIdO1xRbNZ2Xan/68XN64frCCDJkVcxXPSWLUkuYdjCQjJSqZm117p1p2fyBqRkNOKm6fXN6+u8Z4Aris26qyv1f70IBFDfPQO6cqCY2ZlmvcnVPBXLZVZK9yYQ1MW3sjTrrllKG+apWC67UroXdQR1W15OKo+OHEmuBHbtckWxWd/MU19YVF3juDoCBwIzswrMU7FcVKV0ryuLu6QLnVfM+mie+sKy6xodCFqkK51XzPponvrCsusaHQhapCudV8z6aJ5m7GU3gXcdQYu0rfOKmTVLKXUEks6TdI+kx9Pnc0cs83pJD2YeP5P04XTebZJ+nJm3bZ70dJ37JJhZGeYtGroZ+HZEXAF8O31/mog4EBFXRcRVwNXAz4GvZhb574P5EbFveH07xX0SzKwM8waC7cCe9PUe4D1Tln8n8EREjGgRa9PUOVSGmXXXvDemuTAijgNExHFJF0xZ/nrgi0PTbpL0+8AK8B8i4idzpqnTlpZ84jezYk29IpD0LUn7Rzy2z7IhSWcDvwv8j8zkTwO/DFwFHAc+PmH9nZJWJK2srq7Osmkz
M5tgaiCIiHdFxK+MeNwJPC1pM0D6/MyEj7oOeCAins589tMR8XJEnAQ+A2ydkI7dEbEYEYubNm3K+/3MzNasLx04560juAvYkb7eAdw5YdkbGCoWGgSR1HuB/XOmx8wst0kn+j514Jw3EHwUuEbS48A16XskXSTpFy2AJG1I539laP2PSXpE0sPAO4B/P2d6bIq+5HDMppl2ou9TB053KJtiMHLok0/C+vXw8stJa502jiBa553azJpm2oieXezA6UHn1iCbY4AkCEB7LxH7lMMxm2bS3QOXl5NAMEpZHTjrvFr3FcEE43IMA22773AXczhmazXu/71xI7zwwpmZJijvCrqqq3VfEazBtPsKt+2+wx6iwuxUzvvJJ5NMUNag5/6oILB+fXlBYMeOeq/WHQgmmHaCbNsJ1ENUWN8NF/dGnAoGg576zz03et2TJ2cPAtOKewbpGRQ7D6sssxkRrXtcffXVUYW9eyM2bIhIDpfTHxs2JPPbZu/eiIWFCCl5buN3MFurhYXR/+eFhdmWyWPU+WP4vDFuW2vd5jTASow4p9Z+Ul/Lo6pAEHHqxAkR69ef+nHqOoH6RG62dtLoE650apk8J/A88gSUcekpK7PpQNABRR2gdXIgszrlze0XcZzmCTrj0rN+fTn/DQeCDijqknWg6pNyFwKZtVuVx2Ce/2vV/wkHgg7Ik8PIq46TctGBzGwtqsoA5f2PVZkhGxcI3I+gRab1hKzrs/JyPwbrm8HIBEeOJK0M6x6RwP0IOqDI5p+TelWWxf0YrG+WlpKM1cmTyXNTh3JxIGiRIu9QVsdJeVQgk2Cb71RtVisHgpYpKoeR5+qi6LFPlpaSHpTZ3pwRsGdP+8ZtajuPQmunGVVx0PRHXyuLizapkqqsyuSmVBj3uRlrF1tv9fn3nAVuNVS+Lh2MZZ2wi2z5tFZdPBHOoinBuCh9/z1nMS4QuGioIF27m1FZlclNqDDu+3Dcs/62TS9G6uLvWfk+HxUd8j6A9wGPAieBxQnLXQscAA4CN2emnwfcAzyePp+bZ7tNvCLoWi6rrO/ThNxbE65K6jTLb9uE32uarv2eZe5zyigaAv4p8HrgO+MCAbAeeAL4JeBs4CHgynTexwaBAbgZ+K95ttvEQOCDcbbPrrMIrWtBe1az/LZt2FdtSOMsyvw+pQSCX3zI5EDwVuDuzPtbgFvS1weAzenrzcCBPNtrYiDo2sEYUf8JuyxtyOWWLe9v24YMTtd+zzL3+bhAUEUdwcXAU5n3R9NpABdGxHGA9PmCCtJTii6O9d+WzjCzKrI/Rlvl/W2bUKczTdd+zzr2+dRAIOlbkvaPeGzPuQ2NmBazJRMk7ZS0ImlldXV11tVL17WDseu6GuSK1pYMTpd+zzr2+VnTFoiId825jaPApZn3lwDH0tdPS9ocEcclbQaemZCO3cBuSMYamjNNpVhaavcBaDZscDw3abycrqtjnxcy6Jyk7wB/GBFnjAQn6SzgR8A7gR8D9wG/FxGPSvoT4EREfFTSzcB5EfEfp22vr4POmZnNo5RB5yS9V9JRkgrhr0m6O51+kaR9ABHxEnATcDfwGPCliHg0/YiPAtdIehy4Jn1vZmYV8jDUZmY94WGozcxsJAcCM7OecyAwM+u5VtYRSFoFRtxoMZfzgWcLTE5RnK7ZNTVtTtdsmpouaG7a1pquhYjYNDyxlYFgHpJWRlWW1M3pml1T0+Z0zaap6YLmpq3odLloyMys5xwIzMx6ro+BYHfdCRjD6ZpdU9PmdM2mqemC5qat0HT1ro7AzMxO18crAjMzy3AgMDPruU4GAknvk/SopJOSxjaxknStpAOSDqajnw6mnyfpHkmPp8/nFpSuqZ8r6fWSHsw8fibpw+m82yT9ODNvW1XpSpc7LOmRdNsrs65fRrokXSrpbyQ9lv7mf5CZV+j+Gne8ZOZL0ifT+Q9LelPedeeVI21LaZoelvQ9SW/MzBv5u1aUrrdL+mnmN/pI3nVL
TtcfZdK0X9LLks5L55W5vz4n6RlJ+8fML+cYG3XbsrY/qOleyjnSNdPnpmn8PySdQABuIxnuu+j9lStdwGHg/Hm/V5HpIrnF6ZvS1+eQDHk++B0L21+TjpfMMtuAr5PcjOktwA/yrltB2t4GnJu+vm6Qtkm/a0Xpejvw12tZt8x0DS3/buB/lr2/0s/+TeBNwP4x80s5xjp5RRARj0XEgSmLbQUORsShiHgRuAMY3HVtO7Anfb0HeE9BSZv1c98JPBERa+1Fnde837e2/RURxyPigfT1P5AMdX7x8HIFmHS8ZNP7hUh8H3itkhsu5Vm31LRFxPci4ifp2++T3CCqbPN87zL32ayffQPwxYK2PVFE3As8N2GRUo6xTgaCnOq4l/Ksn3s9Zx6AN6WXhJ8rqghmhnQF8E1J90vauYb1y0oXAJK2AL8G/CAzuaj9Nel4mbZMnnXnMevnf5AkVzkw7netKl1vlfSQpK9LesOM65aZLiRtAK4FvpyZXNb+yqOUY2zqrSqbStK3gNeNmHVrRNyZ5yNGTJu7Le2kdM34OWcDvwvckpn8aeCPSdL5x8DHgX9VYbp+PSKOSboAuEfS36c5mDUrcH+9muTP+uGI+Fk6ec37a9QmRkwbPl7GLVPKsZZju2cuKL2DJBD8RmZy4b/rDOl6gKTo8/m0DuevgCtyrltmugbeDfzviMjm0svaX3mUcoy1NhBEQ+6lPEu6JM3yudcBD0TE05nP/sVrSZ8B/rrKdEXEsfT5GUlfJbkcvZea95ekV5AEgeWI+Erms9e8v0aYdLxMW+bsHOvOI0/akPSrwGeB6yLixGD6hN+19HRlgjYRsU/Sn0s6P8+6ZaYr44yr8hL3Vx6lHGN9Lhq6D7hC0uVp7vt64K503l3AjvT1DiDPFUYes3zuGeWS6clw4L3AyJYFZaRL0qsknTN4Dfx2Zvu17S9JAv4SeCwi/nRoXpH7a9Lxkk3v76ctO94C/DQt0sqz7jymfr6ky4CvAB+IiB9lpk/6XatI1+vS3xBJW0nOSSfyrFtmutL0vAb4LTLHXcn7K49yjrEyar7rfpD86Y8C/w94Grg7nX4RsC+z3DaSViZPkBQpDaZvBL4NPJ4+n1dQukZ+7oh0bSD5M7xmaP3bgUeAh9MfeXNV6SJpjfBQ+ni0KfuLpIgj0n3yYPrYVsb+GnW8ADcCN6avBXwqnf8ImRZr4461Ao/5aWn7LPCTzD5amfa7VpSum9LtPkRSif22KvbZtHSl7/8lcMfQemXvry8Cx4F/JDmHfbCKY8xDTJiZ9Vyfi4bMzAwHAjOz3nMgMDPrOQcCM7OecyAwM+s5BwIzs55zIDAz67n/D8qKOkqmMigOAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      " 读者不妨自己调节数据集的参数设置来生成属于自己的数据集吧！\n"
     ]
    }
   ],
   "source": [
    "# Dataset parameters\n",
    "Ntrain = 200        # size of the training set\n",
    "Ntest = 100         # size of the test set\n",
    "boundary_gap = 0.5  # width of the gap around the decision boundary\n",
    "seed_data = 2       # fixed random seed\n",
    "\n",
    "# Generate the dataset\n",
    "train_x, train_y, test_x, test_y = circle_data_point_generator(Ntrain, Ntest, boundary_gap, seed_data)\n",
    "\n",
    "# Visualize the training and test sets\n",
    "print(\"训练集 {} 个数据点的可视化：\".format(Ntrain))\n",
    "data_point_plot(train_x, train_y)\n",
    "print(\"测试集 {} 个数据点的可视化：\".format(Ntest))\n",
    "data_point_plot(test_x, test_y)\n",
    "print(\"\\n 读者不妨自己调节数据集的参数设置来生成属于自己的数据集吧！\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
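    "### Data preprocessing\n",
    "\n",
    "Unlike classical machine learning, a quantum classifier needs an extra preprocessing step: the classical data must first be converted into quantum information before it can be processed on a quantum computer. Let us see how this is done.\n",
    "\n",
    "First we determine the number of qubits needed. Our data points $\\{x^{k} = (x^{k}_0, x^{k}_1)\\}$ are two-dimensional, so the encoding of Mitarai et al. (2018) [1] requires at least 2 qubits. We then prepare a set of initial states $|00\\rangle$, encode the classical data $\\{x^{k}\\}$ into a set of quantum gates $U(x^{k})$, and apply these gates to the initial states. The result is a set of quantum states $|\\psi\\rangle^k = U(x^{k})|00\\rangle$, which completes the encoding from classical to quantum information. When $m$ qubits are used to encode a two-dimensional classical data point, the gate is constructed as\n",
    "\n",
    "$$\n",
    "U(x^{k}) = \\otimes_{j=0}^{m-1} R_j^z\\big[\\arccos(x_{j \\, \\text{mod} \\, 2}\\cdot x_{j \\, \\text{mod} \\, 2})\\big] R_j^y\\big[\\arcsin(x_{j \\, \\text{mod} \\, 2}) \\big],\\tag{2}\n",
    "$$\n",
    "\n",
    "**Note**: in this notation the first qubit is labeled $j = 0$. For more encoding schemes, see [Robust data encodings for quantum classifiers](https://arxiv.org/pdf/2003.01695.pdf). Readers can also directly use the [encoding methods](./DataEncoding_CN.ipynb) provided in Paddle Quantum, and you are welcome to experiment with entirely new encodings of your own.\n",
    "\n",
    "Since this encoding looks somewhat involved, let us work through a simple example. Take the data point $x = (x_0, x_1)= (1,0)$. Its label should clearly be 1, corresponding to the **blue** points in the figure above. The 2-qubit gate $U(x)$ for this data point is\n",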
    "\n",
    "$$\n",
    "U(x) = \n",
    "\\bigg( R_0^z\\big[\\arccos(x_{0}\\cdot x_{0})\\big] R_0^y\\big[\\arcsin(x_{0}) \\big]  \\bigg)\n",
    "\\otimes \n",
    "\\bigg( R_1^z\\big[\\arccos(x_{1}\\cdot x_{1})\\big] R_1^y\\big[\\arcsin(x_{1}) \\big] \\bigg),\\tag{3}\n",
    "$$\n",
    "\n",
    "\n",
    "Substituting the numerical values gives\n",
    "$$\n",
    "U(x) = \n",
    "\\bigg( R_0^z\\big[0\\big] R_0^y\\big[\\pi/2 \\big]  \\bigg)\n",
    "\\otimes \n",
    "\\bigg( R_1^z\\big[\\pi/2\\big] R_1^y\\big[0 \\big] \\bigg),\n",
    "\\tag{4}\n",
    "$$\n",
    "\n",
    "For reference, the matrix forms of the commonly used rotation gates are:\n",
    "\n",
    "\n",
    "$$\n",
    "R_x(\\theta) := \n",
    "\\begin{bmatrix} \n",
    "\\cos \\frac{\\theta}{2} &-i\\sin \\frac{\\theta}{2} \\\\ \n",
    "-i\\sin \\frac{\\theta}{2} &\\cos \\frac{\\theta}{2} \n",
    "\\end{bmatrix}\n",
    ",\\quad \n",
    "R_y(\\theta) := \n",
    "\\begin{bmatrix}\n",
    "\\cos \\frac{\\theta}{2} &-\\sin \\frac{\\theta}{2} \\\\ \n",
    "\\sin \\frac{\\theta}{2} &\\cos \\frac{\\theta}{2} \n",
    "\\end{bmatrix}\n",
    ",\\quad \n",
    "R_z(\\theta) := \n",
    "\\begin{bmatrix}\n",
    "e^{-i\\frac{\\theta}{2}} & 0 \\\\ \n",
    "0 & e^{i\\frac{\\theta}{2}}\n",
    "\\end{bmatrix}. \\tag{5}\n",
    "$$\n",
    "\n",
    "The two-qubit gate $U(x)$ can therefore be written in matrix form as\n",
    "\n",
    "$$\n",
    "U(x) = \n",
    "\\bigg(\n",
    "\\begin{bmatrix}\n",
    "1 & 0 \\\\ \n",
    "0 & 1\n",
    "\\end{bmatrix}\n",
    "\\begin{bmatrix}\n",
    "\\cos \\frac{\\pi}{4} &-\\sin \\frac{\\pi}{4} \\\\ \n",
    "\\sin \\frac{\\pi}{4} &\\cos \\frac{\\pi}{4} \n",
    "\\end{bmatrix}\n",
    "\\bigg)\n",
    "\\otimes \n",
    "\\bigg(\n",
    "\\begin{bmatrix}\n",
    "e^{-i\\frac{\\pi}{4}} & 0 \\\\ \n",
    "0 & e^{i\\frac{\\pi}{4}}\n",
    "\\end{bmatrix}\n",
    "\\begin{bmatrix}\n",
    "1 &0 \\\\ \n",
    "0 &1\n",
    "\\end{bmatrix}\n",
    "\\bigg),\\tag{6}\n",
    "$$\n",
    "\n",
    "Multiplying this out and applying it to the zero-initialized state $|00\\rangle$ gives the encoded quantum state $|\\psi\\rangle$:\n",
    "\n",
    "$$\n",
    "|\\psi\\rangle =\n",
    "U(x)|00\\rangle = \\frac{1}{2}\n",
    "\\begin{bmatrix}\n",
    "1-i &0 &-1+i &0 \\\\ \n",
    "0 &1+i &0  &-1-i \\\\\n",
    "1-i &0 &1-i  &0 \\\\\n",
    "0 &1+i &0  &1+i \n",
    "\\end{bmatrix}\n",
    "\\begin{bmatrix}\n",
    "1 \\\\\n",
    "0 \\\\\n",
    "0 \\\\\n",
    "0\n",
    "\\end{bmatrix}\n",
    "= \\frac{1}{2}\n",
    "\\begin{bmatrix}\n",
    "1-i \\\\\n",
    "0 \\\\\n",
    "1-i \\\\\n",
    "0\n",
    "\\end{bmatrix}.\\tag{7}\n",
    "$$\n",
    "\n",
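    "Next, let us look at how this encoding is implemented in code. Note that the code prepares each qubit separately and combines the single-qubit states with a tensor product, using the identity\n",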
    "\n",
    "$$\n",
    "(U_1 |0\\rangle)\\otimes (U_2 |0\\rangle) = (U_1 \\otimes U_2) |0\\rangle\\otimes|0\\rangle\n",
    "= (U_1 \\otimes U_2) |00\\rangle.\\tag{8}\n",
    "$$"
   ]
  },
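  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a numerical sanity check of eqs. (6)-(8), the sketch below rebuilds $U(x)$ for $x = (1, 0)$ with plain NumPy and compares it with the matrix and the state in eq. (7). The helpers `ry` and `rz` are illustrative re-implementations of eq. (5), not Paddle Quantum functions. Qubit 0 is assumed to be the leftmost factor of the tensor product, the same convention used by the encoder below. If the derivation is correct, every check prints `True`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def ry(theta):\n",
    "    # R_y rotation matrix, eq. (5)\n",
    "    return np.array([[np.cos(theta / 2), -np.sin(theta / 2)],\n",
    "                     [np.sin(theta / 2), np.cos(theta / 2)]])\n",
    "\n",
    "def rz(theta):\n",
    "    # R_z rotation matrix, eq. (5)\n",
    "    return np.diag([np.exp(-0.5j * theta), np.exp(0.5j * theta)])\n",
    "\n",
    "# Single-qubit gates of eq. (3) for the data point x = (1, 0)\n",
    "x0, x1 = 1.0, 0.0\n",
    "U0 = rz(np.arccos(x0 * x0)) @ ry(np.arcsin(x0))   # R_z[0] R_y[pi/2]\n",
    "U1 = rz(np.arccos(x1 * x1)) @ ry(np.arcsin(x1))   # R_z[pi/2] R_y[0]\n",
    "U = np.kron(U0, U1)                                # two-qubit gate U(x), eq. (6)\n",
    "\n",
    "U_expected = 0.5 * np.array([[1 - 1j, 0, -1 + 1j, 0],\n",
    "                             [0, 1 + 1j, 0, -1 - 1j],\n",
    "                             [1 - 1j, 0, 1 - 1j, 0],\n",
    "                             [0, 1 + 1j, 0, 1 + 1j]])   # matrix in eq. (7)\n",
    "psi_expected = 0.5 * np.array([1 - 1j, 0, 1 - 1j, 0])  # state in eq. (7)\n",
    "\n",
    "zero = np.array([1.0, 0.0])\n",
    "psi_joint = U @ np.kron(zero, zero)              # (U_0 (x) U_1)|00>\n",
    "psi_product = np.kron(U0 @ zero, U1 @ zero)      # (U_0|0>) (x) (U_1|0>), eq. (8)\n",
    "\n",
    "print(np.allclose(U, U_expected))\n",
    "print(np.allclose(psi_joint, psi_expected), np.allclose(psi_product, psi_expected))"
   ]
  },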
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-02T09:15:06.589265Z",
     "start_time": "2021-03-02T09:15:06.452691Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "作为测试我们输入以上的经典信息:\n",
      "(x_0, x_1) = (1, 0)\n",
      "编码后输出的2比特量子态为:\n",
      "[[[0.5-0.5j 0. +0.j  0.5-0.5j 0. +0.j ]]]\n"
     ]
    }
   ],
   "source": [
    "def Ry(theta):\n",
    "    \"\"\"\n",
    "    :param theta: rotation angle\n",
    "    :return: R_y rotation matrix\n",
    "    \"\"\"\n",
    "    return np.array([[np.cos(theta / 2), -np.sin(theta / 2)],\n",
    "                     [np.sin(theta / 2), np.cos(theta / 2)]])\n",
    "\n",
    "def Rz(theta):\n",
    "    \"\"\"\n",
    "    :param theta: rotation angle\n",
    "    :return: R_z rotation matrix\n",
    "    \"\"\"\n",
    "    return np.array([[np.cos(theta / 2) - np.sin(theta / 2) * 1j, 0],\n",
    "                     [0, np.cos(theta / 2) + np.sin(theta / 2) * 1j]])\n",
    "\n",
    "# Classical -> quantum data encoder\n",
    "def datapoints_transform_to_state(data, n_qubits):\n",
    "    \"\"\"\n",
    "    :param data: classical data points, shape [-1, 2]\n",
    "    :param n_qubits: number of qubits used to encode each data point\n",
    "    :return: encoded states, shape [-1, 1, 2 ^ n_qubits]\n",
    "    \"\"\"\n",
    "    dim1, dim2 = data.shape\n",
    "    res = []\n",
    "    for sam in range(dim1):\n",
    "        res_state = 1.\n",
    "        zero_state = np.array([[1, 0]])\n",
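    "        for i in range(n_qubits):\n",
    "            # Even-indexed qubits encode x_0 and odd-indexed qubits encode x_1;\n",
    "            # each qubit state is the row vector (R_z R_y |0>)^T of eq. (2)\n",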
    "            if i % 2 == 0:\n",
    "                state_tmp=np.dot(zero_state, Ry(np.arcsin(data[sam][0])).T)\n",
    "                state_tmp=np.dot(state_tmp, Rz(np.arccos(data[sam][0] ** 2)).T)\n",
    "                res_state=np.kron(res_state, state_tmp)\n",
    "            elif i % 2 == 1:\n",
    "                state_tmp=np.dot(zero_state, Ry(np.arcsin(data[sam][1])).T)\n",
    "                state_tmp=np.dot(state_tmp, Rz(np.arccos(data[sam][1] ** 2)).T)\n",
    "                res_state=np.kron(res_state, state_tmp)\n",
    "        res.append(res_state)\n",
    "\n",
    "    res = np.array(res)\n",
    "    return res.astype(\"complex128\")\n",
    "\n",
    "print(\"作为测试我们输入以上的经典信息:\")\n",
    "print(\"(x_0, x_1) = (1, 0)\")\n",
    "print(\"编码后输出的2比特量子态为:\")\n",
    "print(datapoints_transform_to_state(np.array([[1, 0]]), n_qubits=2))"
   ]
  },
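  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loop in `datapoints_transform_to_state` can also be vectorized. Each qubit state has the closed form $R_z(a)R_y(b)|0\\rangle = (e^{-ia/2}\\cos\\frac{b}{2}, e^{ia/2}\\sin\\frac{b}{2})^T$, so the tensor product over qubits becomes a batched outer product. The function `encode_batch` below is a hypothetical helper, not part of Paddle Quantum. It is sketched under the same conventions as the cell above: qubit $j$ encodes $x_{j \\bmod 2}$, and qubit 0 is the leftmost tensor factor. For the test point $(1, 0)$ it should agree numerically with the encoder above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def encode_batch(data, n_qubits):\n",
    "    # Hypothetical vectorized variant of datapoints_transform_to_state:\n",
    "    # qubit j encodes x_{j mod 2} as R_z(arccos(x^2)) R_y(arcsin(x)) |0>\n",
    "    data = np.asarray(data, dtype=float)\n",
    "    batch = data.shape[0]\n",
    "    states = np.ones((batch, 1), dtype=complex)\n",
    "    for j in range(n_qubits):\n",
    "        x = data[:, j % 2]\n",
    "        a, b = np.arccos(x ** 2), np.arcsin(x)\n",
    "        # Closed form of R_z(a) R_y(b) |0>, one row per data point, shape [batch, 2]\n",
    "        qubit = np.stack([np.exp(-0.5j * a) * np.cos(b / 2),\n",
    "                          np.exp(0.5j * a) * np.sin(b / 2)], axis=1)\n",
    "        # Batched Kronecker product; qubit 0 stays the leftmost tensor factor\n",
    "        states = (states[:, :, None] * qubit[:, None, :]).reshape(batch, -1)\n",
    "    return states[:, None, :]   # shape [batch, 1, 2 ** n_qubits]\n",
    "\n",
    "out = encode_batch([[1.0, 0.0]], n_qubits=2)\n",
    "print(np.allclose(out[0, 0], [0.5 - 0.5j, 0, 0.5 - 0.5j, 0]))"
   ]
  },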
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
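    "### Building the quantum neural network\n",
    "\n",
    "Having encoded the classical data into quantum states, we can now feed these states into a quantum computer. Before that, we still need to design the structure of the quantum neural network we will use.\n",
    "\n",
    "![Circuit structure](figures/qclassifier-fig-circuit.png \"Figure 3: Circuit structure of the parameterized quantum neural network\")\n",
    "<div style=\"text-align:center\">Figure 3: Circuit structure of the parameterized quantum neural network </div>\n",
    "\n",
    "\n",
    "For convenience, we denote the parameterized quantum neural network above by $U(\\boldsymbol{\\theta})$. This $U(\\boldsymbol{\\theta})$ is the key component of our classifier and needs a sufficiently rich structure to fit the decision boundary. As with classical neural networks, the design of a quantum neural network is not unique. The network shown here is only one example, and readers are encouraged to design their own. Returning to the data point $x = (x_0, x_1)= (1,0)$ from before, the encoding step gives the quantum state $|\\psi\\rangle$,\n",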
    "\n",
    "$$\n",
    "|\\psi\\rangle =\n",
    "\\frac{1}{2}\n",
    "\\begin{bmatrix}\n",
    "1-i \\\\\n",
    "0 \\\\\n",
    "1-i \\\\\n",
    "0\n",
    "\\end{bmatrix},\\tag{9}\n",
    "$$\n",
    "\n",
    "Next we feed this state into the quantum neural network, which amounts to multiplying a vector by a unitary matrix. This yields the processed state $|\\varphi\\rangle$\n",
    "\n",
    "$$\n",
    "|\\varphi\\rangle = U(\\boldsymbol{\\theta})|\\psi\\rangle,\\tag{10}\n",
    "$$\n",
    "\n",
    "If we set all parameters to $\\theta = \\pi$, we can write out the matrix explicitly (two qubits, circuit depth 1):\n",
    "\n",
    "$$\n",
    "|\\varphi\\rangle = \n",
    "U(\\boldsymbol{\\theta} =\\pi)|\\psi\\rangle =\n",
    "\\begin{bmatrix}\n",
    "0  &0 &-1 &0 \\\\ \n",
    "-1 &0 &0  &0 \\\\\n",
    "0  &1 &0  &0 \\\\\n",
    "0  &0 &0  &1 \n",
    "\\end{bmatrix}\n",
    "\\cdot\n",
    "\\frac{1}{2}\n",
    "\\begin{bmatrix}\n",
    "1-i \\\\\n",
    "0 \\\\\n",
    "1-i \\\\\n",
    "0\n",
    "\\end{bmatrix}\n",
    "= \\frac{1}{2}\n",
    "\\begin{bmatrix}\n",
    "-1+i \\\\\n",
    "-1+i \\\\\n",
    "0 \\\\\n",
    "0\n",
    "\\end{bmatrix}.\\tag{11}\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-02T09:15:06.795398Z",
     "start_time": "2021-03-02T09:15:06.782149Z"
    }
   },
   "outputs": [],
   "source": [
    "# Build the parameterized quantum neural network\n",
    "def U_theta(theta, n, depth):  \n",
    "    \"\"\"\n",
    "    :param theta: circuit parameters, shape [n, depth + 3]\n",
    "    :param n: number of qubits\n",
    "    :param depth: circuit depth\n",
    "    :return: the circuit U_theta\n",
    "    \"\"\"\n",
    "    # Initialize the circuit\n",
    "    cir = UAnsatz(n)\n",
    "    \n",
    "    # First, a layer of general single-qubit rotations Rz-Ry-Rz\n",
    "    for i in range(n):\n",
    "        cir.rz(theta[i][0], i)\n",
    "        cir.ry(theta[i][1], i)\n",
    "        cir.rz(theta[i][2], i)\n",
    "\n",
    "    # The default depth is depth = 1\n",
    "    # Entangling layer (a ring of CNOTs) followed by an Ry rotation layer\n",
    "    for d in range(3, depth + 3):\n",
    "        for i in range(n-1):\n",
    "            cir.cnot([i, i + 1])\n",
    "        cir.cnot([n-1, 0])\n",
    "        for i in range(n):\n",
    "            cir.ry(theta[i][d], i)\n",
    "\n",
    "    return cir"
   ]
  },
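  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see where the matrix in eq. (11) comes from, the sketch below re-derives the circuit of `U_theta` with plain NumPy for $n = 2$ qubits, depth 1, and all parameters set to $\\pi$. It is an independent re-implementation, not Paddle Quantum's simulator. It assumes that qubit 0 is the most significant bit (the leftmost tensor factor), consistent with the encoder above. Note that `Net` below multiplies row-vector states by `cir.U`, which amounts to applying the transpose of this matrix, whereas eq. (11) uses the usual column-vector convention."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def ry(t):\n",
    "    # R_y rotation matrix, eq. (5)\n",
    "    return np.array([[np.cos(t / 2), -np.sin(t / 2)],\n",
    "                     [np.sin(t / 2), np.cos(t / 2)]])\n",
    "\n",
    "def rz(t):\n",
    "    # R_z rotation matrix, eq. (5)\n",
    "    return np.diag([np.exp(-0.5j * t), np.exp(0.5j * t)])\n",
    "\n",
    "def on_qubit(gate, q, n):\n",
    "    # Embed a single-qubit gate acting on qubit q into the n-qubit space\n",
    "    out = np.array([[1.0]])\n",
    "    for k in range(n):\n",
    "        out = np.kron(out, gate if k == q else np.eye(2))\n",
    "    return out\n",
    "\n",
    "def cnot(control, target, n):\n",
    "    # CNOT as a permutation matrix; qubit 0 is the most significant bit\n",
    "    dim = 2 ** n\n",
    "    m = np.zeros((dim, dim))\n",
    "    for i in range(dim):\n",
    "        j = i ^ (1 << (n - 1 - target)) if (i >> (n - 1 - control)) & 1 else i\n",
    "        m[j, i] = 1\n",
    "    return m\n",
    "\n",
    "def u_theta_matrix(theta, n, depth):\n",
    "    # Same gate order as U_theta: Rz-Ry-Rz on every qubit, then depth x (CNOT ring + Ry layer)\n",
    "    U = np.eye(2 ** n, dtype=complex)\n",
    "    for i in range(n):\n",
    "        for gate in (rz(theta[i][0]), ry(theta[i][1]), rz(theta[i][2])):\n",
    "            U = on_qubit(gate, i, n) @ U\n",
    "    for d in range(3, depth + 3):\n",
    "        for i in range(n - 1):\n",
    "            U = cnot(i, i + 1, n) @ U\n",
    "        U = cnot(n - 1, 0, n) @ U\n",
    "        for i in range(n):\n",
    "            U = on_qubit(ry(theta[i][d]), i, n) @ U\n",
    "    return U\n",
    "\n",
    "theta = np.full((2, 4), np.pi)                 # n = 2, depth = 1, all parameters pi\n",
    "U = u_theta_matrix(theta, n=2, depth=1)\n",
    "U_eq11 = np.array([[0, 0, -1, 0], [-1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]])\n",
    "psi = 0.5 * np.array([1 - 1j, 0, 1 - 1j, 0])   # encoded state, eq. (9)\n",
    "phi = U @ psi                                   # eq. (10)\n",
    "print(np.allclose(U, U_eq11))\n",
    "print(np.allclose(phi, 0.5 * np.array([-1 + 1j, -1 + 1j, 0, 0])))"
   ]
  },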
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
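    "### Measurement and loss function\n",
    "\n",
    "After the quantum neural network has processed the initial state $|\\psi\\rangle$ on a quantum computer (QPU), we need to measure the new state $|\\varphi\\rangle$ to obtain classical information. This classical information is used to compute the loss function $\\mathcal{L}(\\boldsymbol{\\theta})$. A classical computer (CPU) then iteratively updates the QNN parameters $\\boldsymbol{\\theta}$ to minimize the loss. The measurement used here is the expectation value of the Pauli $Z$ operator on the first qubit. Specifically,\n",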
    "\n",
    "$$\n",
    "\\langle Z \\rangle = \n",
    "\\langle \\varphi |Z\\otimes I\\cdots \\otimes I| \\varphi\\rangle,\\tag{12}\n",
    "$$\n",
    "\n",
    "Recall that the Pauli $Z$ operator has the matrix form\n",
    "\n",
    "$$\n",
    "Z := \\begin{bmatrix} 1 &0 \\\\ 0 &-1 \\end{bmatrix},\\tag{13}\n",
    "$$\n",
    "\n",
    "Continuing our two-qubit example, the expectation value obtained from the measurement is\n",
    "$$\n",
    "\\langle Z \\rangle = \n",
    "\\langle \\varphi |Z\\otimes I| \\varphi\\rangle = \n",
    "\\frac{1}{2}\n",
    "\\begin{bmatrix}\n",
    "-1-i \\quad\n",
    "-1-i \\quad\n",
    "0   \\quad\n",
    "0\n",
    "\\end{bmatrix}\n",
    "\\begin{bmatrix}\n",
    "1  &0 &0  &0 \\\\ \n",
    "0  &1 &0  &0 \\\\\n",
    "0  &0 &-1 &0 \\\\\n",
    "0  &0 &0  &-1 \n",
    "\\end{bmatrix}\n",
    "\\cdot\n",
    "\\frac{1}{2}\n",
    "\\begin{bmatrix}\n",
    "-1+i \\\\\n",
    "-1+i \\\\\n",
    "0 \\\\\n",
    "0\n",
    "\\end{bmatrix}\n",
    "= 1,\\tag{14}\n",
    "$$\n",
    "\n",
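    "A curious reader may ask: this measurement result equals the original label 1, so have we already classified this data point correctly? Not quite. $\\langle Z \\rangle$ always takes values in $[-1,1]$, while our labels satisfy $y^{k} \\in \\{0,1\\}$, so we still need to map one interval onto the other. The simplest such mapping is\n",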
    "\n",
    "$$\n",
    "\\tilde{y}^{k} = \\frac{\\langle Z \\rangle}{2} + \\frac{1}{2} + bias \\quad \\in [0, 1].\\tag{15}\n",
    "$$\n",
    "\n",
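    "Adding a bias is a common machine learning trick that keeps the decision boundary from being tied to the origin or to a fixed hyperplane. The bias starts close to 0 (the code below draws it from a normal distribution with standard deviation 0.01), and the optimizer updates it alongside the parameters $\\theta$. Strictly speaking, a nonzero bias means $\\tilde{y}^{k}$ is no longer guaranteed to lie in $[0, 1]$, but this does not affect thresholding at 0.5. Readers may also choose other, more elaborate mappings (activation functions) such as the sigmoid function. After the mapping, $\\tilde{y}^{k}$ serves as our predicted label: $\\tilde{y}^{k}< 0.5$ corresponds to label 0, and $\\tilde{y}^{k}> 0.5$ corresponds to label 1. To recap, the whole pipeline is\n",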
    "\n",
    "\n",
    "$$\n",
    "x^{k} \\rightarrow |\\psi\\rangle^{k} \\rightarrow U(\\boldsymbol{\\theta})|\\psi\\rangle^{k} \\rightarrow\n",
    "|\\varphi\\rangle^{k} \\rightarrow ^{k}\\langle \\varphi |Z\\otimes I\\cdots \\otimes I| \\varphi\\rangle^{k}\n",
    "\\rightarrow \\langle Z \\rangle  \\rightarrow \\tilde{y}^{k}.\\tag{16}\n",
    "$$\n",
    "\n",
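    "Finally, we take the squared loss as the loss function (the code below averages over each batch instead of summing, which only rescales the loss by a constant factor):\n",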
    "\n",
    "$$\n",
    "\\mathcal{L} = \\sum_{k} |y^{k} - \\tilde{y}^{k}|^2.\\tag{17}\n",
    "$$\n",
    "\n"
   ]
  },
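  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The measurement and post-processing steps in eqs. (14), (15) and (17) can be checked numerically for the running example. The sketch below starts from the state $|\\varphi\\rangle$ of eq. (11) and uses a bias of 0. The batch labels and predictions in the second half are made-up values. They only illustrate the squared loss and the 0.5 threshold, and show that averaging (as the training code does) differs from summing only by a constant factor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "phi = 0.5 * np.array([-1 + 1j, -1 + 1j, 0, 0])    # state after the QNN, eq. (11)\n",
    "ZI = np.kron(np.diag([1.0, -1.0]), np.eye(2))      # Z on qubit 0, identity on qubit 1\n",
    "\n",
    "exp_z = np.real(np.conj(phi) @ ZI @ phi)           # <Z>, eq. (14)\n",
    "bias = 0.0\n",
    "y_tilde = exp_z / 2 + 0.5 + bias                   # mapped prediction, eq. (15)\n",
    "predicted_label = int(y_tilde > 0.5)               # threshold at 0.5\n",
    "loss_single = abs(1 - y_tilde) ** 2                # one term of eq. (17), true label 1\n",
    "print(exp_z, y_tilde, predicted_label, loss_single)\n",
    "\n",
    "# Made-up batch of labels and predictions, only to illustrate eq. (17)\n",
    "y_batch = np.array([1.0, 0.0, 1.0])\n",
    "pred_batch = np.array([0.9, 0.2, 0.4])\n",
    "loss_sum = np.sum((y_batch - pred_batch) ** 2)     # eq. (17)\n",
    "loss_mean = np.mean((y_batch - pred_batch) ** 2)   # what the training code computes\n",
    "accuracy = np.mean(np.abs(pred_batch - y_batch) < 0.5)\n",
    "print(loss_sum, loss_mean, accuracy)"
   ]
  },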
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-02T09:15:07.667491Z",
     "start_time": "2021-03-02T09:15:07.661325Z"
    }
   },
   "outputs": [],
   "source": [
    "# Pauli Z operator acting only on the first qubit,\n",
    "# with the identity on all remaining qubits\n",
    "def Observable(n):\n",
    "    \"\"\"\n",
    "    :param n: number of qubits\n",
    "    :return: local observable Z \\otimes I \\otimes ... \\otimes I\n",
    "    \"\"\"\n",
    "    Ob = pauli_str_to_matrix([[1.0, 'z0']], n)\n",
    "    return Ob"
   ]
  },
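  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`pauli_str_to_matrix([[1.0, 'z0']], n)` returns the matrix of $Z \\otimes I \\otimes \\cdots \\otimes I$. For readers without Paddle Quantum at hand, a minimal sketch of the same matrix can be built with `np.kron`. It assumes qubit 0 is the leftmost tensor factor, the convention used throughout this tutorial. `observable_z0` is a hypothetical helper for illustration; for $n = 2$ its diagonal is $(1, 1, -1, -1)$, matching the matrix in eq. (14)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from functools import reduce\n",
    "\n",
    "def observable_z0(n):\n",
    "    # Hypothetical helper: Z on qubit 0 (leftmost factor), identity on the other n - 1 qubits\n",
    "    Z = np.diag([1.0, -1.0])\n",
    "    return reduce(np.kron, [Z] + [np.eye(2)] * (n - 1))\n",
    "\n",
    "print(np.diag(observable_z0(2)))\n",
    "print(np.diag(observable_z0(3)))"
   ]
  },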
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-02T09:15:08.373511Z",
     "start_time": "2021-03-02T09:15:08.358729Z"
    }
   },
   "outputs": [],
   "source": [
    "# Build the full training model\n",
    "class Net(paddle.nn.Layer):\n",
    "    \"\"\"\n",
    "    Model used to train the quantum classifier\n",
    "    \"\"\"\n",
    "    def __init__(self,\n",
    "                 n,      # number of qubits\n",
    "                 depth,  # circuit depth\n",
    "                 seed_paras=1,\n",
    "                 dtype='float64'):\n",
    "        super(Net, self).__init__()\n",
    "\n",
    "        self.n = n\n",
    "        self.depth = depth\n",
    "        \n",
    "        # Initialize the parameters theta, drawn uniformly from [0, 2*pi]\n",
    "        self.theta = self.create_parameter(\n",
    "            shape=[n, depth + 3],\n",
    "            default_initializer=paddle.nn.initializer.Uniform(low=0.0, high=2*PI),\n",
    "            dtype=dtype,\n",
    "            is_bias=False)\n",
    "        \n",
    "        # Initialize the bias\n",
    "        self.bias = self.create_parameter(\n",
    "            shape=[1],\n",
    "            default_initializer=paddle.nn.initializer.Normal(std=0.01),\n",
    "            dtype=dtype,\n",
    "            is_bias=False)\n",
    "\n",
    "    # Forward pass: compute the loss function and the classification accuracy\n",
    "    def forward(self, state_in, label):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            state_in: The input quantum state, shape [-1, 1, 2^n]\n",
    "            label: label for the input state, shape [-1, 1]\n",
    "        Returns:\n",
    "            The loss:\n",
    "                L = ((<Z> + 1)/2 + bias - label)^2\n",
    "        \"\"\"\n",
    "        # Convert NumPy arrays to tensors\n",
    "        Ob = paddle.to_tensor(Observable(self.n))\n",
    "        label_pp = paddle.to_tensor(label)\n",
    "\n",
    "        # Build the circuit from the current parameters theta\n",
    "        cir = U_theta(self.theta, n=self.n, depth=self.depth)\n",
    "        Utheta = cir.U\n",
    "        \n",
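    "        # Multiplying the row-vector states by Utheta effectively applies its transpose;\n",
    "        # since Utheta is learned, this speeds up the computation without affecting training\n",
    "        state_out = matmul(state_in, Utheta)  # shape [-1, 1, 2 ** n]\n",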
    "        \n",
    "        # Measure the expectation value <Z> of the Pauli Z observable\n",
    "        E_Z = matmul(matmul(state_out, Ob), transpose(paddle.conj(state_out), perm=[0, 2, 1]))\n",
    "        \n",
    "        # Map <Z> to the predicted label, eq. (15)\n",
    "        state_predict = paddle.real(E_Z)[:, 0] * 0.5 + 0.5 + self.bias\n",
    "        loss = paddle.mean((state_predict - label_pp) ** 2)  # mean squared error over the batch\n",
    "        \n",
    "        # Classification accuracy on this batch\n",
    "        is_correct = (paddle.abs(state_predict - label_pp) < 0.5).nonzero().shape[0]\n",
    "        acc = is_correct / label.shape[0]\n",
    "\n",
    "        return loss, acc, state_predict.numpy(), cir"
   ]
  },
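  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The tensor operations in `Net.forward` can be mirrored in NumPy to make the shapes and the row-vector convention explicit. In the sketch below, a random unitary obtained from a QR decomposition stands in for `cir.U`, and the input states and labels are random or made up. These are assumptions for illustration only. The loop at the end checks that the batched formula equals $\\langle\\varphi|Z\\otimes I|\\varphi\\rangle$ with $|\\varphi\\rangle = U^T|\\psi\\rangle$, i.e. that the row-vector product applies the transpose of `Utheta`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(7)\n",
    "n, batch = 2, 3\n",
    "dim = 2 ** n\n",
    "\n",
    "# A random unitary from a QR decomposition stands in for cir.U (illustrative assumption)\n",
    "U, _ = np.linalg.qr(rng.normal(size=(dim, dim)) + 1j * rng.normal(size=(dim, dim)))\n",
    "\n",
    "# Random normalized input states with the shape used by Net: [batch, 1, 2 ** n]\n",
    "raw = rng.normal(size=(batch, dim)) + 1j * rng.normal(size=(batch, dim))\n",
    "state_in = (raw / np.linalg.norm(raw, axis=1, keepdims=True))[:, None, :]\n",
    "labels = np.array([[1.0], [0.0], [1.0]])           # made-up labels, shape [batch, 1]\n",
    "Ob = np.kron(np.diag([1.0, -1.0]), np.eye(2))      # Z on qubit 0\n",
    "bias = 0.0\n",
    "\n",
    "# NumPy mirror of the tensor operations in Net.forward\n",
    "state_out = state_in @ U                                       # [batch, 1, dim]\n",
    "E_Z = state_out @ Ob @ np.conj(state_out).transpose(0, 2, 1)   # [batch, 1, 1]\n",
    "state_predict = np.real(E_Z)[:, 0] * 0.5 + 0.5 + bias          # [batch, 1]\n",
    "loss = np.mean((state_predict - labels) ** 2)\n",
    "acc = np.mean(np.abs(state_predict - labels) < 0.5)\n",
    "\n",
    "# Column-vector check: the row-vector product corresponds to |phi> = U^T |psi>\n",
    "for k in range(batch):\n",
    "    phi = U.T @ state_in[k, 0]\n",
    "    assert np.isclose(np.real(np.conj(phi) @ Ob @ phi), np.real(E_Z[k, 0, 0]))\n",
    "print(state_predict.shape, loss, acc)"
   ]
  },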
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training results and hyperparameter tuning\n",
    "\n",
    "With all of the above concepts defined, let us look at how the training actually performs!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-02T09:15:08.911819Z",
     "start_time": "2021-03-02T09:15:08.887770Z"
    }
   },
   "outputs": [],
   "source": [
    "def heatmap_plot(net, N):\n",
    "    # Generate a Num_points x Num_points grid of data points x_y_\n",
    "    Num_points = 30\n",
    "    x_y_ = []\n",
    "    for row_y in np.linspace(0.9, -0.9, Num_points):\n",
    "        row = []\n",
    "        for row_x in np.linspace(-0.9, 0.9, Num_points):\n",
    "            row.append([row_x, row_y])\n",
    "        x_y_.append(row)\n",
    "    x_y_ = np.array(x_y_).reshape(-1, 2).astype(\"float64\")\n",
    "\n",
    "    # Compute the predictions: heat_data\n",
    "    input_state_test = paddle.to_tensor(\n",
    "        datapoints_transform_to_state(x_y_, N))\n",
    "    # Dummy labels of shape [-1, 1]; only the predictions are used here\n",
    "    loss_useless, acc_useless, state_predict, cir = net(state_in=input_state_test, label=x_y_[:, 0:1])\n",
    "    heat_data = state_predict.reshape(Num_points, Num_points)\n",
    "\n",
    "    # Plot\n",
    "    fig = plt.figure(1)\n",
    "    ax = fig.add_subplot(111)\n",
    "    x_label = np.linspace(-0.9, 0.9, 3)\n",
    "    y_label = np.linspace(0.9, -0.9, 3)\n",
    "    ax.set_xticks([0, Num_points // 2, Num_points - 1])\n",
    "    ax.set_xticklabels(x_label)\n",
    "    ax.set_yticks([0, Num_points // 2, Num_points - 1])\n",
    "    ax.set_yticklabels(y_label)\n",
    "    im = ax.imshow(heat_data, cmap=plt.cm.RdBu)\n",
    "    plt.colorbar(im)\n",
    "    plt.show()\n",
    "\n",
    "def QClassifier(Ntrain, Ntest, gap, N, D, EPOCH, LR, BATCH, seed_paras, seed_data,):\n",
    "    \"\"\"\n",
    "    Quantum binary classifier\n",
    "    \"\"\"\n",
    "    # Generate the dataset\n",
    "    train_x, train_y, test_x, test_y = circle_data_point_generator(Ntrain=Ntrain, Ntest=Ntest, boundary_gap=gap, seed_data=seed_data)\n",
    "\n",
    "    # Number of training samples\n",
    "    N_train = train_x.shape[0]\n",
    "    \n",
    "    paddle.seed(seed_paras)\n",
    "    # Build the model\n",
    "    net = Net(n=N, depth=D)\n",
    "\n",
    "    # In general, the Adam optimizer gives relatively good convergence;\n",
    "    # you can also switch to SGD or RMSprop\n",
    "    opt = paddle.optimizer.Adam(learning_rate=LR, parameters=net.parameters())\n",
    "\n",
    "    # Lists for recording the iteration index and the test accuracy\n",
    "    summary_iter, summary_test_acc = [], []\n",
    "\n",
    "    # Optimization loop\n",
    "    for ep in range(EPOCH):\n",
    "        for itr in range(N_train // BATCH):\n",
    "\n",
    "            # Encode the classical data into quantum states |psi>, shape [-1, 1, 2 ** N]\n",
    "            input_state = paddle.to_tensor(datapoints_transform_to_state(train_x[itr * BATCH:(itr + 1) * BATCH], N))\n",
    "\n",
    "            # Forward pass: compute the loss function\n",
    "            loss, train_acc, state_predict_useless, cir \\\n",
    "                = net(state_in=input_state, label=train_y[itr * BATCH:(itr + 1) * BATCH])\n",
    "            if itr % 50 == 0:\n",
    "\n",
    "                # Compute the accuracy test_acc on the test set\n",
    "                input_state_test = paddle.to_tensor(datapoints_transform_to_state(test_x, N))\n",
    "                loss_useless, test_acc, state_predict_useless, t_cir \\\n",
    "                    = net(state_in=input_state_test,label=test_y)\n",
    "                print(\"epoch:\", ep, \"iter:\", itr,\n",
    "                      \"loss: %.4f\" % loss.numpy(),\n",
    "                      \"train acc: %.4f\" % train_acc,\n",
    "                      \"test acc: %.4f\" % test_acc)\n",
    "                # Record the global iteration index and the test accuracy\n",
    "                summary_iter.append(itr + ep * (N_train // BATCH))\n",
    "                summary_test_acc.append(test_acc) \n",
    "            if (itr + 1) % 151 == 0 and ep == EPOCH - 1:\n",
    "                print(\"训练后的电路：\")\n",
    "                print(cir)\n",
    "\n",
    "            # Backpropagate and take an optimizer step to minimize the loss\n",
    "            loss.backward()\n",
    "            opt.minimize(loss)\n",
    "            opt.clear_grad()\n",
    "\n",
    "    # Plot the decision boundary as a heatmap\n",
    "    heatmap_plot(net, N=N)\n",
    "\n",
    "    return summary_test_acc"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "All of the above are function definitions. Next, we run the main program."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-02T09:15:50.771171Z",
     "start_time": "2021-03-02T09:15:09.593720Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "训练集的维度大小 x (200, 2) 和 y (200, 1)\n",
      "测试集的维度大小 x (100, 2) 和 y (100, 1) \n",
      "\n",
      "epoch: 0 iter: 0 loss: 0.0318 train acc: 1.0000 test acc: 0.5400\n",
      "epoch: 0 iter: 50 loss: 0.3359 train acc: 0.0000 test acc: 0.8200\n",
      "epoch: 0 iter: 100 loss: 0.0396 train acc: 1.0000 test acc: 0.8700\n",
      "epoch: 0 iter: 150 loss: 0.0952 train acc: 1.0000 test acc: 1.0000\n",
      "epoch: 1 iter: 0 loss: 0.1586 train acc: 1.0000 test acc: 1.0000\n",
      "epoch: 1 iter: 50 loss: 0.1534 train acc: 1.0000 test acc: 1.0000\n",
      "epoch: 1 iter: 100 loss: 0.0624 train acc: 1.0000 test acc: 1.0000\n",
      "epoch: 1 iter: 150 loss: 0.0883 train acc: 1.0000 test acc: 1.0000\n",
      "epoch: 2 iter: 0 loss: 0.1627 train acc: 1.0000 test acc: 1.0000\n",
      "epoch: 2 iter: 50 loss: 0.1378 train acc: 1.0000 test acc: 1.0000\n",
      "epoch: 2 iter: 100 loss: 0.0669 train acc: 1.0000 test acc: 1.0000\n",
      "epoch: 2 iter: 150 loss: 0.0860 train acc: 1.0000 test acc: 1.0000\n",
      "epoch: 3 iter: 0 loss: 0.1658 train acc: 1.0000 test acc: 1.0000\n",
      "epoch: 3 iter: 50 loss: 0.1359 train acc: 1.0000 test acc: 1.0000\n",
      "epoch: 3 iter: 100 loss: 0.0671 train acc: 1.0000 test acc: 1.0000\n",
      "epoch: 3 iter: 150 loss: 0.0849 train acc: 1.0000 test acc: 1.0000\n",
      "训练后的电路：\n",
      "--Rz(0.542)----Ry(3.456)----Rz(2.699)----*--------------X----Ry(6.153)--\n",
      "                                         |              |               \n",
      "--Rz(3.514)----Ry(1.543)----Rz(2.499)----X----*---------|----Ry(3.050)--\n",
      "                                              |         |               \n",
      "--Rz(5.947)----Ry(3.161)----Rz(3.897)---------X----*----|----Ry(1.583)--\n",
      "                                                   |    |               \n",
      "--Rz(0.718)----Ry(5.038)----Rz(1.348)--------------X----*----Ry(0.030)--\n",
      "                                                                        \n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAATsAAAD5CAYAAABYi5LMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAfYElEQVR4nO3dW4wkZ3UH8P+pvs19Z70XG4wxhJiLg7BFjE1kocQgHONAEAQJw4OJA1klgjzkAUF4yAMRkiNHIiBusQghluLwAFlwIsfYIgpGEQgMGN8wjlmMsdf2MmvvbaZn+lInD1Wz6Wn3OfVVd8109fT/J7W8O19/VTVdtcd1OX2OqCqIiHa7aNwbQES0ExjsiGgqMNgR0VRgsCOiqcBgR0RTgcGOiKZCNe8EEbkGwKcAVAB8UVVv7BvfC+BLAF4GYB3An6jqA1nLrcwuaW3x4MCxKBJ/m5xx8aY6g+685B1DzXUXm7FO8d4w7DozDDt3lIQmd64zmJVG5Y16c73FZq4zdpbrjMXeWLftrjNutwb/fP0ktN0c5XBAtPQiRWc96L3aPP5NVb1mlPUVKVewE5EKgM8CeDOAJwD8QERuU9WHet72MQD3quo7ROSV6fvflLXs2uJBvPjdfzdwbGau5s6tN+xfo1q3T16rtYo5FlX8k95K1R53x5xIWMkI6t74sGNZhp3bdf6xjjLXG2t1Yne57a49HjtzO+2uvc6Njr/ODXvuRtMOWs3TG/bYiWfcda4ee3zgz9d+dIs7L0hnHdVX/GHQW9v3/tP+0VdYnLyXsZcDeFRVj6hqC8BXALy97z0XA/gWAKjqwwBeIiLnjrylRDR+IpCoEvQqm7zB7nwAv+r5+xPpz3r9BMA7AUBELgdwIYAXDbuBRFQmgqhaD3qVTd57doOuafqvK24E8CkRuRfA/QB+DGDgub6IHAJwCACqCwdybgoR7bj0zG4S5Q12TwC4oOfvLwJwtPcNqnoKwA0AICIC4Bfp63lU9WYANwPAzMHf5Jd0iUpOAEhlOoLdDwBcJCIvBfAkgOsAvLf3DSKyDGAtvaf3AQB3pwGQiCadCKJpOLNT1Y6IfAjAN5GknnxJVR8UkT9Lx78A4FUAbhGRLoCHALw/ZNm1RhXnv+ycgWMvPjDvzj241HDGZsyxpRn7159zntQCQMN5WttwnsbOOGO1yL+FWqvYT0a9uRmLdXlPjz3dEarptLv23LbzNNZ72goAG84T1zXnieuZlj12uuU/jT12yn6q+tSJpjn2+PE1c+zEsSV3nStLgx+Cth78qjsv1LRcxkJVbwdwe9/PvtDz5+8CuGj0TSOi0pmie3ZENMUEgqjq572WFYMdEYXjmR0RTQsGOyLa/USmJvWEiKaYYHLP7FjiiYjCSYRKtR70ClqcyDUi8jMReVREPjpgfI+I/LuI/EREHhSRG0Ln9ivNmd1co4LXXLg8cOziF/p5Rect2Hl2++fsD32Pk2c36+TDAcBM1c4/qzv5cHWnikg1q+qJNxzb+V7ijCH2c9NGK9bkcJL/NLKf9mlk77N2xq/ScvLwNpzcvqaz4LWO//k8c8bOs3vaGTvi5NndN+s/DbWqtBzLyB0NIsWd2QVWUfoggIdU9W0icgDAz0TkXwB0A+ZuwTM7IgomKLTqSUgVJQWwmH71dAHAs0i+ax8yd4vSnNkR0WQo8J7doCpKV/S95zMAbkPyHfxFAO9W1VhEQuZuwWBHROHy5dntF5F7ev5+c1r84+zSBszpvy/w+wDuBfBGJNXP7xKR7wTO3YLBjohyyBXsVlT1Mmc8s4oSkgpKN2pS//5REfkFgFcGzt2CwY6IgokIolphhTkzqygBeBxJW4fvpBXPXwHgCIATAXO3YLAjonAFfl0ssIrS3wD4sojcj+TS9SOqupJsyvPneusrTbBrVCNcdO7iwLGLzvFLPL1g0U492dOwHzgv1O2dVtfBHZo2RRur5phs2GkD0rLL+kjX
Xydadlcn7ThzO3ZjF43t8kUAAlJTDBl1pdx/MM4XzaPGrDlWr9rHAQDM1uy52lgwx+K5wcclAGzATwNZco6/c5wUEq+EWNMpOQUAx54dfPxVMppIhSoyqTigitJRAFeHzvWUJtgR0WTIam1aVgx2RBRMRNw+zWXGYEdEuRR1ObzTGOyIKJyAZ3ZEtPslVU8Y7Iho1xNEQzZhGrfSBLtaJcJ5RgqJl1oCAAfm7EfhC5Fd8SNaXbHHmifddcqaPR6fOWGOdVftrpLatNNZACBet1Na1EtLadtpKZrRkUuHTD2RjNSTqO6kbDipJ1K3u8VF8351nMhJIYn27LPHFgZ3vUuWudddZ8Np/u51bmt37TSZlTU7lQgA7jOqAEVu2ZxAvIwlomnBYEdEu54IUHFqOZYZgx0R5SK8Z0dEu52I8BsURDQdeM+OiKYCg92IqpGYzXG8yhGAn15SOfWMPbb2rDnWWXnKXWf83DFzrOulnpw5Y461TtmpJQDQXrXTSzrrdvOWuG1/PnHLacaD7Uw9sQ+9qGaP1eadqidLc+46G8tO6smSnXpS2XeePXbghe46ofbnt3fhXHOs2bE/g/1zfqWVcxYG/zuqZuyTIALm2RHR7icQRBmd98qKwY6IwsnklniazBBNRGMjIkGvwGVlNcn+sIjcm74eEJGuiJyTjj0mIvenY/c8f+lb8cyOiIIlhQAKWlZAk2xVvQnATen73wbgL1W192b7VZtl2rMw2BFRuGIvY882ugYAEdlsdP2Q8f73APjXYVfGy1giykEQVaKgV4BBja7PH7hWkTkA1wD4Ws+PFcCdIvJDETmUtTKe2RFRMMl3ZldEk+xNbwPwP32XsFeq6lEROYikefbDqnq3tTGlCXbVSMxuSwu1jJytpp0v5+bSPf24PXbcz7PrPGvfJmget8s4bZw4bY5l5dl13Dw7u+xPu+nk2bUzuosNSTLKCXl5eLVZL8/OLvdVX8zIs9tr59nNHrDLazWc8lnI6M5Wqdg9Vmt1u2vevNMJba/TlQwAlo08vEpBl585koqLaJK96Tr0XcKmncegqsdE5DCSy2Iz2PEyloiCiSRBM+QV4GyTbBGpIwlotz1/nbIHwO8C+EbPz+ZFZHHzz0jaLT7graw0Z3ZENBmKOkMMbJINAO8AcKeq9p5+nwvgcJriUgVwq6re4a2PwY6IggmCz9qCZDXJTv/+ZQBf7vvZEQCX5FkXgx0RBRMB6vy6GBHtdiLJw8RJxGBHRMEExd2z22mlCXYCoGGkKzQiK/Um4XUC63qlmJyx1q/tMQBoHjthj/36OXNs44Rd4mn9pF2mCQA2TtnjnXU7vaTjpJ5o7H+2cUb3MUtWUqmXvlB1Uk8aS3bqSWPJSREB0Fn3uqwNl4LjdTsDgGjR7j4Wze4xx+bm7bSURac8FgAszGxj6okUe89uJ5Um2BFR+SVndrxnR0RTgGd2RLTrRSJ8GktE06HCsuxEtNttfl1sEjHYEVEuDHYjEhHUjdQT2bDTNQAgatkVK9qnT5hj3ZPHzbF1p3IJ4KeXrB2zU2Gaz9npEevOGAC0Vu3UidYZu+qJl5YSZ6SedP1hU0bRE1RrFXtsxj4s26vO7+mk2ABAtzVceklUs6uM1JfsYwgAYucYi/YcNMdmluwPcK5uf3YAMGuMF3H1yaRiIpoKAj6gIKIpwHt2RDQV+HUxIpoOE3xmN5kX30Q0Fpv17AqqVDxq31h3bj8GOyLKpahg19M39i0ALgbwHhG5uPc9qnqTql6qqpcC+CsA31bVZ0Pm9ivNZWwksFNPOn41kLhpp57Eq17zGzulxWuMkzXXSy9ZW7Gb6mycslNLknH7c1ht29VJmk7lknZGaklXh8s9ycqyn3G2d3bDTiGZ89JohqzQAmQ1+bErm2ycWHCXW9lnH3+Vjn2cSNc+FrIehjaMN0QF5J5ExRbvHKVvbN65PLMjohyKbbgzSt/Y4LmbSnNmR0TlJ5A8343dzr6xeeYCYLAj
opxyXA5vZ9/YPHMB8DKWiHIQJF8FDHkFGLpvbOjcXjyzI6JwAkQl6BtrzfXWx2BHRMEEQK3AsuzD9o215noY7Igo2OZl7CQqVbCzSsd4OUcAoOtOnp0z1l6185xap+x8OMDvBOblw3m5dM0Tfomnk05u2rpTqsnLs2uNqcRT07kUWo/tM4eu2nl2cspfacUpK1Vf8Mo42ceCdwwBfg4oNprmkHS8PDu7wxpg58EVUmBYpLDL2J1WqmBHROUmKCY5eRwY7IgoF17GEtGuJwLUMhqglxWDHREF42UsEU0NXsYS0a4nEJ7ZjcrL35GO3VEKAOIN+/F/vOalDdiP/jtrfkqB1+Vq2LEzHb9E0aqTQuLNbTr5I9mpJ9tT4qnupC8Mne6y5h8nXtcyrztbd91OA/HGAEBb9nHkjUnX3p6K2CWnAKBu3FMrJERNcKXi0gQ7Iiq/5J7duLdiOAx2RBSs6K+L7SQGOyIKJ8CEZp4w2BFROKaeENGUyFWpuFQY7IgoGM/siiDOhxjblS4AQJ0KEd22043KGeu2uu46O+tOeknTXu6Gs1wvRSRr3B8bR9WTrOXa/2C8pFVvbCbjMeG805ms44x5+7Oz7ne+UyctSr2UKrX3WdY9M+vfkRQQpJKvi01msJvQW41ENC4iYa+wZWU3uhaR30ubZD8oIt/u+fljInJ/OnbPoLm9ynNmR0QTISomPbm3SfabkTTQ+YGI3KaqD/W8ZxnA5wBco6qPi8jBvsVcpaorYdtNRBRIUOiZ3dlG16raArDZ6LrXewH8m6o+DgCqemzYbWewI6JcIgl7BQhpdP1yAHtF5L9F5Icicn3PmAK4M/35oayV8TKWiMLluB+HYppkVwH8NoA3AZgF8F0R+Z6qPgLgSlU9ml7a3iUiD6vq3dbGMNgRUTDJl2dXRJPsJ9LlrAJYFZG7AVwC4BFVPQokl7YichjJZXH5g53btch5DA/4j/C99JK4NXzqSbdlb5M310v1aGdUGPHmbscYsH0NdzxeGk09GuF3cfbLsGPeMQQA6qT9aNupmOKkW0UZH+52f52rwEIAZxtdA3gSSaPr9/a95xsAPiMiVQB1AFcA+KSIzAOIVPV0+uerAXzcW1lpgh0RTYaiYl1Ik2xV/amI3AHgPgAxgC+q6gMi8hsADqe5g1UAt6rqHd76GOyIKFjR36AIbJJ9E4Cb+n52BMnlbDAGOyLKZUK/LcZgR0T5TGq+GoMdEQUTlmUnomnBy1gi2vUEvIzdVpKRZ4fYz4mzaGwvN85IMIu9/Ckn38vr1pWV0+bPHW65WevM+ORtmb+LNzbc75K1rbGzX7x95h0nXh5d1txJVUSpqHGYiGBHRCUR/r3X0mGwI6Jg7jedSo7Bjohy4WUsEe16bJJNRFNjQmMdgx0R5SET210sd8pMVoMMSXw6Hb9PRF5bzKYOR7ux+SKinAJLspcxHuY6swtpkAHgLQAuSl9XAPh8+l8imnCiChkyr3Xc8p7ZhTTIeDuAWzTxPQDLIvKCAraViEpANA56lU3eYBfSICPkPUQ0kTSpHB7yCjBi39jMub3yPqAIaZAR8p7kjUlHoEMAcMEFFwx6CxGVTUb7gFCj9I0NvKW2Rd4zu9AGGVnvAQCo6s2qepmqXrZv//6cm0JEO04LPbMbpW9syNwt8ga7sw0yRKSOpEHGbX3vuQ3A9elT2dcDOKmqT+VcDxGVVIH37EbpG5v7dlmuy9iQBhlI6slfC+BRAGsAbsizjoHrlYyYHFXMIXFaLUlkj2V1cIrc5dpzvTZ0Wd859Od6Y8VcduSR/bt4Y8N9Rln/546c/eLtM+848Y6vrLmTSd3OZ322rW9s4NznLSiXrAYZqqoAPph3uUQ0ARTBDx+wvX1jg2+Xbdpt/9shom2lQByHvbKF3Bb7BoA3iEhVROaQ5Oz+NHDuFvy6GBHlUlQO3Sh9YwFg0FxvfQx2RJRPgQnDw/aNteZ6GOyIKJzq0G0Qxo3BjohyKeNXwUIw
2BFRDlroZexOKk2wUzidozLy7KRaM8eimv0rRnV7LCt/qlK3xyt1O++v7uRz1TLq4nhzvbGuDl9vJ6v7mCUrz27Y32XYMcDPjfT2mTfmHUNARp5nrW6OebmlcUbrtm2vXsZgR0S7nvLMjoimgID37IhoKijQ5dNYItrt8n1drFQY7IgoF17GEtEU4AOK0SkQWxVQo4zH+1X7EX7FSz1xxmqz/jqrM3a6ize3sWqnMcxm5Ay0nRQSP73ESbnIKP/UHbIqrVemCfDTRGadvJVZJ5XDmwcA1Rl7v/ipRM6YcwwBgDRm7DEnZUqdYz4rtcT6d6QFVRhmsCOi3Y9fFyOi6aDQTnvcGzEUBjsiCqfgmR0R7X4KhTLPjoh2PUVoFeLSYVl2IsohfUAR8gqQ1eg6bZB9Mm2Sfa+I/HXP2GMicn/683v65/YrzZmdAugY/8NQ5xE94D/ej+bmzLHa/Kw5Vp2zl5nM3XDGnLSUpt2ZaaHtHyDDVyCxJ7bi8aSeDJteMu/Mm2v4h3NjyU5Raiw1zLHavH0sVJ1jCACk7qSeOGNasY+hbsaB0DJyUwpJPNHiHlDkaHT9HVV9q7GYq1R1JWR9pQl2RDQJFFrcA4qzja4BQEQ2G133B7tC8DKWiMJtPo0t5jI2tNH174jIT0TkP0Xkt/q25s60efahrJXxzI6IctA8DyiKaJL9IwAXquoZEbkWwNcBXJSOXamqR0XkIIC7RORhVb3b2hgGOyIKp8iTejJyk2xVPdXz59tF5HMisl9VV1T1aPrzYyJyGMllsRnseBlLRDkU+jQ2s9G1iJwnkjztEpHLkcSs4yIyLyKL6c/nAVwN4AFvZTyzI6JwBT6NDWmSDeBdAP5cRDoAmgCuU1UVkXMBHE7jYBXArap6h7e+0gQ7VaBtpEFoxU4ZAIBoZn6osfqSnZbijQHAzNq6OdZZt9NLuq3hn2RFp+x0l3rbTsmYiex7LFnpLMOmj2ZdMgxb9cRLL5nZa6ePAH56iTdWX3SOE2cMAKJZ+/hDw05b0ZqdltJZ9/dKy8jhKqboSbGFALKaZKvqZwB8ZsC8IwAuybOu0gQ7IpoA/G4sEU0DhUIn9OtiDHZEFI5ndkQ0FVSh7da4t2IoDHZElEOupOJSYbAjonx4GUtEu54WWghgR5Um2CmAlpH0pbWMMjpzC+ZYtLhsjjWWT5hjrVNr7jrjlp1LFzvtnyInh6xStzuPAX7XsvoZO9Fz0cn762xbWSl/vFqzf1evC9iwZZoAYG6/fRzN7Ft0xvaYY/XlJXed0cKyOaY1O0evK/ZnsNGx8y0BYM3I5TS79+XEp7FEtPupQrN6OZYUgx0RBVNVxG37SqHMGOyIKJyCZ3ZENB0Y7Iho11NVxGylSETTgE9jR6SqdurJgp1aAgDasB//R3v22WOrp82x+fOGr9klTnes6swZZ8zfHa1Ve5s6TteyUUpOaUb3MYs4JZwAP83G+xzqC3bXrazUEy+9ZO7gXmds2RyrLJ3jrtM7/uKGXf6pabXaA7CWkS7UNNKiCkk9meCnsaxUTETBNp/GhrxCjNg31p3brzRndkQ0Gbyk+TxG6RubY+5ZPLMjonBp6knIK8DZvrGq2gKw2Td2W+Yy2BFRuPSeXUHBbpS+saFzz+JlLBEFU+R6GrudfWND5m7BYEdE4VTdIhh9tq1vbMjcfqUJdrHaj9TXOv5mLszaVSkqywfNsarTEk4i/wp/sW6nQNTm7c5Q9SU73aCxvOqus73qdDRzup156SVZqSfxkGVPvOouQFbqif3ZVt3P1u/01Vh2Uk8OLNvr3H+eOVY5+CJ3nVjcbw6pc9yutezP/eSGH2xOrg0+rrtDphFtoUBcXJ7d2b6xAJ5E0jf2vb1vEJHzADyTtk882zcWwImsuf1KE+yIqPwUxeXZjdI3FsDAud76GOyIKJwCWuDXxYbtG2vN9TDYEVEObKVIRNOAJZ6IaBqoKrrhT2NLhcGOiHLgZezIOnGMlbXBzXf3
zfmbOTtnV6xAbP9fqCJ2eonU7RQHAJB5u9LK3MJxc2zmgF1ppX3arogCAB0v9WTdblzsfSk762bzsN+DjJzKLwAgFTv1pDJjN9Xx0npqc/4+qywt22NOdZLKPjv1BMvnuuuMFw6YY2tiV2k507bTop5r+hV5TlipJ8N2T+rFy1gimgoKaBFBcwwY7IgomEILq3qy0xjsiCicDl/QddwY7IgomGr2VwzLisGOiMKp8p4dEU2HYYtDjBuDHRGFY+rJ6Nqx4tjq4Fyxg/N+16ia08lqed7Oc9KqvdyobpdiAoDqHnu5aNq5dPGaPVZr+iWetGXn2elG0x5zSllpnHH/ZdgE0owSWRLZeXZStUs8yay9X2TG32eRkxspC8vmWNywS0PF83Z+HgCsVWbNseNrdv7j06ftvMlnzmy461w5PXi8U8CDBQUQ8wEFEe16qnxAQUS7nzKpmIimwgQHO3YXI6Ickm9QhLxChDa6FpHXiUhXRN7V87PHROT+tHn2PdbcTTyzI6JwBX6DIrTRdfq+v0VSgr3fVaq6ErI+ntkRUTBFkmcX8goQ2uj6LwB8DcCxUba9NGd2rW6Mx58bnD6xUPc3sx3bKSTNtj13ob5gjs0u2ekGANCI7J0prTV7rGOnDUjHTi0BgMhJIfFKWYkzBh1TzpRTXksrdokndVJWtOaXeOrW7DQQrdudyTbUXueZtv/5nXTKMXnpJT9/zj6GjhzzS4GtnRp8HBXyBX5VxMU9jR3U6PqK3jeIyPkA3gHgjQBe1781AO4UEQXwD309aZ+nNMGOiMpPNdc3KIpokv33AD6iql2R5739SlU9KiIHAdwlIg+r6t3WxjDYEVEuOSoVj9wkG8BlAL6SBrr9AK4VkY6qfl1VjwKAqh4TkcNILosZ7IioABp8Py5EZpNsVX3p5p9F5MsA/kNVvy4i8wAiVT2d/vlqAB/3VsZgR0ThCsyzC2ySbTkXwOH0jK8K4FZVvcNbH4MdEQVTFFsIIKtJdt/P/7jnz0cAXJJnXQx2RBROFd0Wq56MpNnq4oEnTg4cy6rWsLJmpxzsn7PTGPY0nLSUhp1uAABzNXu8XrFTYWqRva3V+rK7TmexqDiVX5yhgY/Dts7Nesdgsfr7zBv1drd3LLQzjpNO2x5fX7f/Aa+N0Olrxej0BfjVS/73abs6zsNPnnLXeXJlcNpKtzN6kFLN3rdlVZpgR0SToctgR0S7nQKY0DoADHZElA/P7Iho14sVaLFSMRFNA17GEtGup1Bexo5qo9XFz3/53MCxZ0761UD2zjvpJXN285ZlZ2w2o9LKXN1OPZl1xuoVu9pHreKneXhpIF7TIS8tJct2pZ50nUsh7yrJ+4fWzjjlWO/Y1TqaTiWP0+t21ZiTa3blEgA4fsYetxrjAHblEsBOLdl04qknB/6846TQhOIDCiKaGgx2RLTrqfJpLBFNAQWfxhLRFOA9OyKaGryMJaJdL7lnN+6tGE6u7mKS+HTa4/E+EXmt8b43isiPROQBEflnEWFQJdoluqpBr7LJG4TeAuCi9HUFgM/j+d2AIgD/DOBNqvqIiHwcwPsA/KO34FZzA7968NGBY7UZuwsYAFTqdteoStXOeYuqdqzPyi9zmmNBvHJLzpg3L8uAZiRBvO3ZTvGQN7nV+UeU1c/UW2fcsce6TrHKTsvp3Aag2xrcMQ8A2ut2l7D2ql3GaeP0s+46N04NbqPaXV9154VQAEVWsxORawB8Ckml4i+q6o3G+14H4HsA3q2qX80zd1PevrFvB3CLJr4HYFlEXtD3nn0ANlT1kfTvdwH4o5zrIaISUihacdgrS0+T7LcAuBjAe0TkYuN9W5pkh87tlTfYDerzeH7fe1YA1ERks6vQu7C1gxARTajkaWxhl7GjNMkOnXtW3mCX2edRk+uM6wB8UkS+D+A0gIHn+iJySETuEZF74g2/8S8RlUD6gCLkhbRvbM/rUN/SMk+eeppk9/elCDnx2iLznp2IfBDAn6Z/
/QGy+zxCVb8L4A3p/KsBvHzQstOGuTcDQO2cC8t3R5OIttg8swuU1Td2lCbZIXO3yAx2qvpZJNfGEJE/APAhEfkKkgcTJ1X1qf45InIwbVzbAPARAJ/IWg8RTYYCU0+GbpIdOHeLvE9jbwdwLYBHAawBuGFzQERuB/CBtEv3h0XkrUgukz+vqv+Vcz1EVEIxCv262ChNsqtZc/uJ9yh/J4nIrwH8Mv3rfiQPOmhycJ+VU+9+uVBVD4yyMBG5I11miBVVvSZjedciuVTdbJL9CatJdk+w+6o1111XWYJdLxG5J+Nan0qG+6ycuF/+X96nsUREE4nBjoimQlmD3c3j3gDKjfusnLhfUqW8Z0dEVLSyntkRERVqrMGOJaMmj4hcIyI/S/fZRweMB+1TKk7APtkrIofT/fF9EXn1OLZz3MZ9ZtdbMuoQkpJRW/SUjLpOVV+NJBfvfTu5kZQIrDSRuU+pOIH75GMA7lXV1wC4HklZpKkz7mDHklGTJaTSRMg+peKE7JOLAXwLAFT1YQAvEZFzd3Yzx2/cwY4loyZLyP7KXY2CRhLyef8EwDsBQEQuB3Ahku+STpVxB7tCS0bRtgupNJG7GgWNJOTzvhHAXhG5F0ltuB9jCv8N7fiN/u0sGUXbLqTSRO5qFDSSzM9bVU8hLdohSfmQX6SvqbLjZ3aq+llVvVRVLwXwdQDXp0/wXg+nZFT6382SUf2F/GhnnK1SISJ1JGfct/W95zYE7FMqTOY+EZHldAwAPgDg7jQATpVxp3CwZNQEUdWOiHwISS+AzUoTD/ZVqTD3KRUvcJ+8CsAtItIF8BCA949tg8eI36Agoqkw7gcUREQ7gsGOiKYCgx0RTQUGOyKaCgx2RDQVGOyIaCow2BHRVGCwI6Kp8H9PccXvAbZjYgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
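      "The main program ran for 35.94043445587158 seconds in total\n"
     ]
    }
   ],
   "source": [
    "def main():\n",
    "    \"\"\"\n",
    "    Entry point: train and test the quantum classifier.\n",
    "    \"\"\"\n",
    "    time_start = time.time()\n",
    "    acc = QClassifier(\n",
    "        Ntrain = 200,        # size of the training set\n",
    "        Ntest = 100,         # size of the test set\n",
    "        gap = 0.5,           # width of the gap around the decision boundary\n",
    "        N = 4,               # number of qubits\n",
    "        D = 1,               # circuit depth\n",
    "        EPOCH = 4,           # number of training epochs\n",
    "        LR = 0.01,           # learning rate\n",
    "        BATCH = 1,           # batch size used during training\n",
    "        seed_paras = 19,     # random seed for initializing the parameters\n",
    "        seed_data = 2,       # fixed random seed for generating the dataset\n",
    "    )\n",
    "    \n",
    "    time_span = time.time() - time_start\n",
    "    print('The main program ran for', time_span, 'seconds in total')\n",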
    "\n",
    "if __name__ == '__main__':\n",
    "    main()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
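    "The printed training results show that, after optimization, the classifier reaches an accuracy of $100\\%$ on both the training set and the test set."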
   ]
  },
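  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the accuracy quoted above can be read as the fraction of evaluated data points whose predicted label matches the true label. This is an assumed definition for illustration; the exact metric computed inside `QClassifier` is not shown in this section:\n",
    "\n",
    "$$\n",
    "\\mathrm{Acc} = \\frac{1}{N}\\sum_{k=1}^{N} \\mathbb{1}\\left[\\tilde{y}^{k} = y^{k}\\right],\n",
    "$$\n",
    "\n",
    "where $\\mathbb{1}[\\cdot]$ equals $1$ when the condition holds and $0$ otherwise, and $N$ is the number of points in the evaluated set (for example `Ntrain = 200` or `Ntest = 100` above, not the qubit count `N` passed to `QClassifier`). An accuracy of $100\\%$ therefore means $\\tilde{y}^{k} = y^{k}$ for every point in that set."
   ]
  },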
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## References\n",
    "\n",
    "[1] Mitarai, Kosuke, et al. Quantum circuit learning. [Physical Review A 98.3 (2018): 032309.](https://arxiv.org/abs/1803.00745)\n",
    "\n",
    "[2] Farhi, Edward, and Hartmut Neven. Classification with quantum neural networks on near term processors. [arXiv preprint arXiv:1802.06002 (2018).](https://arxiv.org/abs/1802.06002)\n",
    "\n",
    "[3] Schuld, Maria, et al. Circuit-centric quantum classifiers. [Physical Review A 101.3 (2020): 032308.](https://arxiv.org/abs/1804.00633)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
